
Commit 472f394

Adrian Wälchli, kuynzereb, Borda, and jeremyjordan authored

Resolve some codefactor issues (#756)
* remove unnecessary pass statements
* use isinstance for type checks
* remove unnecessary else/elif after return
* remove unnecessary return statements
* move doc string to top
* merge isinstance calls
* remove unnecessary else/elif after raise
* use list comprehension
* do not use len without comparison
* add missing shebang
* revert isinstance check back to type (isinstance broke tests, because bool is actually a subclass of int)
* add missing period to doc string
* Fix default ckpt path when logger exists (#771)
* rename logging -> loggers (#767): move logging >> loggers, add warning, fix tests, logging alias, formatting
* add more detail to tbptt example (#755): warn user about new arg in training_step

Co-authored-by: Vadim Bereznyuk <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jeremy Jordan <[email protected]>
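
One bullet above is worth a note: the isinstance revert was needed because bool is a subclass of int in Python, so an isinstance check silently accepts booleans where an exact type() comparison does not. A minimal illustration (not code from this commit):

# bool is a subclass of int, so isinstance() treats True/False as ints,
# while an exact type() comparison keeps them apart.
flag = True

print(issubclass(bool, int))  # True  -> the root cause
print(isinstance(flag, int))  # True  -> isinstance lets a bool through an "is it an int?" check
print(type(flag) is int)      # False -> the exact check the commit reverted back to
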
1 parent 5e97e66 commit 472f394

15 files changed (+46, -81 lines)

.run_local_tests.sh  (+2)

@@ -1,3 +1,5 @@
+#!/usr/bin/env bash
+
 # use this to run tests
 rm -rf _ckpt_*
 rm -rf tests/save_dir*

pytorch_lightning/callbacks/pt_callbacks.py  (+3, -5)

@@ -27,7 +27,7 @@ def set_params(self, params):
         self.params = params

     def set_model(self, model):
-        if type(model) is LightningDistributedDataParallel:
+        if isinstance(model, LightningDistributedDataParallel):
             model = model.module
         self.model = model

@@ -43,7 +43,6 @@ def on_epoch_begin(self, epoch, logs=None):

             on_epoch_begin(epoch=2, logs={'val_loss': 0.2})
         """
-        pass

     def on_epoch_end(self, epoch, logs=None):
         pass
@@ -56,7 +55,6 @@ def on_batch_begin(self, batch, logs=None):
             batch (Tensor): current batch tensor
             logs (dict): key-value pairs of quantities to monitor
         """
-        pass

     def on_batch_end(self, batch, logs=None):
         pass
@@ -143,7 +141,7 @@ def check_metrics(self, logs):
         if monitor_val is None:
             if self.strict:
                 raise RuntimeError(error_msg)
-            elif self.verbose > 0:
+            if self.verbose > 0:
                 warnings.warn(error_msg, RuntimeWarning)

             return False
@@ -399,7 +397,7 @@ def __init__(self, scheduling: dict):
         if minimal_epoch < 1:
             msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
             raise IndexError(msg)
-        elif minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
+        if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
             scheduling.update({1: 1})

         self.scheduling = scheduling
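
For context on the hunk above: the scheduling dict maps a starting epoch to a gradient-accumulation factor, epochs are indexed from 1, and a factor of 1 is assumed for epoch 1 when the user omits it. A standalone sketch of that validation logic (normalize_schedule is a hypothetical helper, not part of the library):

def normalize_schedule(scheduling: dict) -> dict:
    """Validate an {epoch: accumulation_factor} dict the way the scheduler above does."""
    minimal_epoch = min(scheduling.keys())
    if minimal_epoch < 1:  # epochs are indexed from 1
        raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
    if minimal_epoch != 1:  # user did not define a factor for epoch 1
        scheduling.update({1: 1})
    return scheduling

print(normalize_schedule({4: 2}))  # {4: 2, 1: 1} -> accumulate 1 batch until epoch 4, then 2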

pytorch_lightning/core/hooks.py  (-11)

@@ -32,22 +32,19 @@ def on_sanity_check_start(self):
         .. warning:: will be deprecated.
         :return:
         """
-        pass

     def on_train_start(self):
         """Called at the beginning of training before sanity check
         :return:
         """
         # do something at the start of training
-        pass

     def on_train_end(self):
         """
         Called at the end of training before logger experiment is closed
         :return:
         """
         # do something at the end of training
-        pass

     def on_batch_start(self, batch):
         """Called in the training loop before anything happens for that batch.
@@ -56,32 +53,26 @@ def on_batch_start(self, batch):
         :return:
         """
         # do something when the batch starts
-        pass

     def on_batch_end(self):
         """Called in the training loop after the batch."""
         # do something when the batch ends
-        pass

     def on_epoch_start(self):
         """Called in the training loop at the very beginning of the epoch."""
         # do something when the epoch starts
-        pass

     def on_epoch_end(self):
         """Called in the training loop at the very end of the epoch."""
         # do something when the epoch ends
-        pass

     def on_pre_performance_check(self):
         """Called at the very beginning of the validation loop."""
         # do something before validation starts
-        pass

     def on_post_performance_check(self):
         """Called at the very end of the validation loop."""
         # do something before validation end
-        pass

     def on_before_zero_grad(self, optimizer):
         """Called after optimizer.step() and before optimizer.zero_grad()
@@ -99,7 +90,6 @@ def on_before_zero_grad(self, optimizer):
         :return:
         """
         # do something with the optimizer or inspect it.
-        pass

     def on_after_backward(self):
         """Called after loss.backward() and before optimizers do anything.
@@ -122,7 +112,6 @@ def on_after_backward(self):
                                                  global_step=self.trainer.global_step)

         """
-        pass

     def backward(self, use_amp, loss, optimizer, optimizer_idx):
         """Override backward with your own implementation if you need to

pytorch_lightning/core/lightning.py  (+2, -10)

@@ -253,7 +253,6 @@ def training_step(self, batch, batch_idx, hiddens):
         You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
         break out of the current training epoch early.
         """
-        pass

     def validation_step(self, *args, **kwargs):
         r"""
@@ -326,7 +325,6 @@ def validation_step(self, batch, batch_idx, dataset_idx):
         .. note:: When the validation_step is called, the model has been put in eval mode and PyTorch gradients
             have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
         """
-        pass

     def test_step(self, *args, **kwargs):
         """return whatever outputs will need to be aggregated in test_end
@@ -395,7 +393,6 @@ def test_step(self, batch, batch_idx, dataset_idx):

         The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`.
         """
-        pass

     def validation_end(self, outputs):
         """Outputs has the appended output after each validation step.
@@ -467,7 +464,6 @@ def validation_end(self, outputs):
             return results

         """
-        pass

     def test_end(self, outputs):
         """Outputs has the appended output after each test step.
@@ -532,7 +528,6 @@ def test_end(self, outputs):
             return results

         """
-        pass

     def configure_ddp(self, model, device_ids):
         r"""
@@ -842,8 +837,7 @@ def tbptt_split_batch(self, batch, split_size):
         Each returned batch split is passed separately to training_step(...).

         """
-        time_dims = [len(x[0]) for x in batch if isinstance(
-            x, torch.Tensor) or isinstance(x, collections.Sequence)]
+        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
         assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

@@ -1192,7 +1186,6 @@ def on_load_checkpoint(self, checkpoint):
         .. note:: Lighting auto-restores global step, epoch, and all training state including amp scaling.
             No need for you to restore anything regarding training.
         """
-        pass

     def on_save_checkpoint(self, checkpoint):
         r"""
@@ -1216,7 +1209,6 @@ def on_save_checkpoint(self, checkpoint):
             for you to store anything about training.

         """
-        pass


 def load_hparams_from_tags_csv(tags_csv):
@@ -1236,7 +1228,7 @@ def load_hparams_from_tags_csv(tags_csv):
 def convert(val):
     constructors = [int, float, str]

-    if type(val) is str:
+    if isinstance(val, str):
         if val.lower() == 'true':
             return True
         if val.lower() == 'false':
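
The tbptt_split_batch hunk above merges two isinstance calls into a single tuple check without changing behaviour: every tensor or sequence in the batch contributes len(x[0]) as its time dimension, and all of them must agree. A rough standalone sketch (infer_time_dims is illustrative and uses collections.abc.Sequence, the current home of the alias the library imports from collections):

import collections.abc

import torch


def infer_time_dims(batch):
    # tensors and sequences contribute len(x[0]) as their time dimension
    time_dims = [len(x[0]) for x in batch
                 if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
    return time_dims[0]


# e.g. inputs and targets both shaped (batch, time, features)
x = torch.zeros(4, 20, 8)
y = torch.zeros(4, 20, 1)
print(infer_time_dims([x, y]))  # 20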

pytorch_lightning/core/memory.py  (+2, -8)

@@ -77,7 +77,7 @@ def get_variable_sizes(self):
             if isinstance(input_, (list, tuple)):  # pragma: no cover
                 in_size = []
                 for x in input_:
-                    if type(x) is list:
+                    if isinstance(x, list):
                         in_size.append(len(x))
                     else:
                         in_size.append(x.size())
@@ -97,7 +97,6 @@ def get_variable_sizes(self):
         self.in_sizes = in_sizes
         self.out_sizes = out_sizes
         assert len(in_sizes) == len(out_sizes)
-        return

     def get_layer_names(self):
         '''Collect Layer Names'''
@@ -112,21 +111,17 @@ def get_layer_names(self):

         self.layer_names = names
         self.layer_types = layer_types
-        return

     def get_parameter_sizes(self):
         '''Get sizes of all parameters in `model`'''
         mods = self.named_modules()
         sizes = []
         for _, m in mods:
             p = list(m.parameters())
-            modsz = []
-            for j in range(len(p)):
-                modsz.append(np.array(p[j].size()))
+            modsz = [np.array(param.size()) for param in p]
             sizes.append(modsz)

         self.param_sizes = sizes
-        return

     def get_parameter_nums(self):
         '''Get number of parameters in each layer'''
@@ -137,7 +132,6 @@ def get_parameter_nums(self):
                 all_params += np.prod(p)
             param_nums.append(all_params)
         self.param_nums = param_nums
-        return

     def make_summary(self):
         '''

pytorch_lightning/core/saving.py  (-4)

@@ -7,14 +7,12 @@ def on_load_checkpoint(self, checkpoint):
         :param checkpoint:
         :return:
         """
-        pass

     def on_save_checkpoint(self, checkpoint):
         """
         Give the model a chance to add something to the checkpoint.
         state_dict is already there
         """
-        pass

     # -------------------------
     # OPTIONAL HOOKS
@@ -24,11 +22,9 @@ def on_hpc_save(self, checkpoint):
         Hook to do whatever you need right before Slurm manager saves the model
         :return:
         """
-        pass

     def on_hpc_load(self, checkpoint):
         """
         Hook to do whatever you need right before Slurm manager loads the model
         :return:
         """
-        pass

pytorch_lightning/loggers/base.py  (-3)

@@ -43,18 +43,15 @@ def log_hyperparams(self, params):

     def save(self):
         """Save log data."""
-        pass

     def finalize(self, status):
         """Do any processing that is necessary to finalize an experiment.

         :param status: Status that the experiment finished with (e.g. success, failed, aborted)
         """
-        pass

     def close(self):
         """Do any cleanup that is necessary to close an experiment."""
-        pass

     @property
     def rank(self):

pytorch_lightning/loggers/tensorboard.py  (+2, -2)

@@ -143,5 +143,5 @@ def _get_next_version(self):

         if len(existing_versions) == 0:
             return 0
-        else:
-            return max(existing_versions) + 1
+
+        return max(existing_versions) + 1

pytorch_lightning/overrides/data_parallel.py  (+4, -4)

@@ -25,7 +25,7 @@ def get_a_var(obj): # pragma: no cover
     if isinstance(obj, torch.Tensor):
         return obj

-    if isinstance(obj, list) or isinstance(obj, tuple):
+    if isinstance(obj, (list, tuple)):
         for result in map(get_a_var, obj):
             if isinstance(result, torch.Tensor):
                 return result
@@ -56,10 +56,10 @@ def forward(self, *inputs, **kwargs):
             # lightning
             if self.module.training:
                 return self.module.training_step(*inputs[0], **kwargs[0])
-            elif self.module.testing:
+            if self.module.testing:
                 return self.module.test_step(*inputs[0], **kwargs[0])
-            else:
-                return self.module.validation_step(*inputs[0], **kwargs[0])
+
+            return self.module.validation_step(*inputs[0], **kwargs[0])

         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
         outputs = self.parallel_apply(replicas, inputs, kwargs)

pytorch_lightning/trainer/distrib_data_parallel.py  (+1, -1)

@@ -246,7 +246,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):

         # when slurm is managing the task it sets the visible devices
         if not is_slurm_managing_tasks:
-            if type(data_parallel_device_ids) is int:
+            if isinstance(data_parallel_device_ids, int):
                 id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
                 os.environ["CUDA_VISIBLE_DEVICES"] = id_str
             else:
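
A small illustration of the branch above: when data_parallel_device_ids is an int it is treated as a GPU count and expanded into an explicit comma-separated device list (the value below is hypothetical):

n_gpus = 3  # data_parallel_device_ids passed as an int
id_str = ','.join(str(x) for x in range(n_gpus))
print(id_str)  # "0,1,2" -> what gets assigned to CUDA_VISIBLE_DEVICES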
