
Commit 1051c18

elliotwaite authored and williamFalcon committed
Simplify variables: step, epoch, max_epochs, min_epochs (#589)
1 parent c6e0dbe commit 1051c18

18 files changed: +90 −88 lines
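Downstream code only needs to swap keyword names; a minimal before/after sketch of the rename's effect on user code (model construction elided):

```python
from pytorch_lightning import Trainer

# before this commit the keywords were max_num_epochs / min_num_epochs:
# trainer = Trainer(max_num_epochs=10, min_num_epochs=2)

# after this commit:
trainer = Trainer(max_epochs=10, min_epochs=2)
```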

README.md (+4 −4)

````diff
@@ -154,16 +154,16 @@ use something other than tensorboard).
 Here are more advanced examples
 ```python
 # train on cpu using only 10% of the data (for demo purposes)
-trainer = Trainer(max_num_epochs=1, train_percent_check=0.1)
+trainer = Trainer(max_epochs=1, train_percent_check=0.1)

 # train on 4 gpus (lightning chooses GPUs for you)
-# trainer = Trainer(max_num_epochs=1, gpus=4, distributed_backend='ddp')
+# trainer = Trainer(max_epochs=1, gpus=4, distributed_backend='ddp')

 # train on 4 gpus (you choose GPUs)
-# trainer = Trainer(max_num_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')
+# trainer = Trainer(max_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')

 # train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
-# trainer = Trainer(max_num_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')
+# trainer = Trainer(max_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')

 # train (1 epoch only here for demo)
 trainer.fit(model)
````

pl_examples/full_examples/imagenet/imagenet_example.py (+1 −1)

```diff
@@ -234,7 +234,7 @@ def main(hparams):
     trainer = pl.Trainer(
         default_save_path=hparams.save_path,
         gpus=hparams.gpus,
-        max_num_epochs=hparams.epochs,
+        max_epochs=hparams.epochs,
         distributed_backend=hparams.distributed_backend,
         use_amp=hparams.use_16bit
     )
```

pytorch_lightning/core/lightning.py (+2 −2)

```diff
@@ -694,10 +694,10 @@ def configure_optimizers(self):
         """
         raise NotImplementedError

-    def optimizer_step(self, epoch_idx, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
         """Do something instead of the standard optimizer behavior

-        :param int epoch_idx:
+        :param int epoch:
         :param int batch_idx:
         :param optimizer:
         :param optimizer_idx:
```
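For anyone overriding this hook, only the first parameter's name changes. A minimal sketch of an override against the renamed signature (the warm-up schedule and base learning rate are illustrative, not part of this commit, and it assumes the module's `trainer` reference is attached, as Lightning does during fit):

```python
import pytorch_lightning as pl


class WarmupModel(pl.LightningModule):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None):
        # illustrative: linearly warm up the LR over the first 500 global steps
        if self.trainer.global_step < 500:
            scale = float(self.trainer.global_step + 1) / 500
            for pg in optimizer.param_groups:
                pg['lr'] = scale * 1e-3  # assumed base LR of 1e-3

        optimizer.step()
        optimizer.zero_grad()
```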

pytorch_lightning/logging/__init__.py (+1 −1)

```diff
@@ -52,7 +52,7 @@ def log_hyperparams(self, params):
         pass

     @rank_zero_only
-    def log_metrics(self, metrics, step_idx):
+    def log_metrics(self, metrics, step):
         # metrics is a dictionary of metric names and values
         # your code to record metrics goes here
         pass
```

pytorch_lightning/logging/base.py (+2 −2)

```diff
@@ -21,11 +21,11 @@ class LightningLoggerBase(object):
     def __init__(self):
         self._rank = 0

-    def log_metrics(self, metrics, step_idx):
+    def log_metrics(self, metrics, step):
         """Record metrics.

         :param float metric: Dictionary with metric names as keys and measured quanties as values
-        :param int|None step_idx: Step number at which the metrics should be recorded
+        :param int|None step: Step number at which the metrics should be recorded
         """
         raise NotImplementedError()
```
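Custom loggers track the same rename; a minimal sketch of a subclass against the new signature (a hypothetical print-based backend, with the import path assumed from this commit's tree):

```python
from pytorch_lightning.logging.base import LightningLoggerBase, rank_zero_only


class PrintLogger(LightningLoggerBase):
    """Hypothetical logger that just prints metrics."""

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # `step` (formerly `step_idx`) is the trainer's global step, or None
        print(f'step={step}: {metrics}')

    @rank_zero_only
    def log_hyperparams(self, params):
        print(f'hparams: {params}')
```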

pytorch_lightning/logging/comet.py (+2 −2)

```diff
@@ -145,13 +145,13 @@ def log_hyperparams(self, params):
         self.experiment.log_parameters(vars(params))

     @rank_zero_only
-    def log_metrics(self, metrics, step_idx=None):
+    def log_metrics(self, metrics, step=None):
         # Comet.ml expects metrics to be a dictionary of detached tensors on CPU
         for key, val in metrics.items():
             if is_tensor(val):
                 metrics[key] = val.cpu().detach()

-        self.experiment.log_metrics(metrics, step=step_idx)
+        self.experiment.log_metrics(metrics, step=step)

     @rank_zero_only
     def finalize(self, status):
```

pytorch_lightning/logging/mlflow.py (+2 −2)

```diff
@@ -68,15 +68,15 @@ def log_hyperparams(self, params):
             self.experiment.log_param(self.run_id, k, v)

     @rank_zero_only
-    def log_metrics(self, metrics, step_idx=None):
+    def log_metrics(self, metrics, step=None):
         timestamp_ms = int(time() * 1000)
         for k, v in metrics.items():
             if isinstance(v, str):
                 logger.warning(
                     f"Discarding metric with string value {k}={v}"
                 )
                 continue
-            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step_idx)
+            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

     def save(self):
         pass
```

pytorch_lightning/logging/test_tube.py (+2 −2)

```diff
@@ -76,10 +76,10 @@ def log_hyperparams(self, params):
         self.experiment.argparse(params)

     @rank_zero_only
-    def log_metrics(self, metrics, step_idx=None):
+    def log_metrics(self, metrics, step=None):
         # TODO: HACK figure out where this is being set to true
         self.experiment.debug = self.debug
-        self.experiment.log(metrics, global_step=step_idx)
+        self.experiment.log(metrics, global_step=step)

     @rank_zero_only
     def save(self):
```

pytorch_lightning/trainer/logging.py (+1 −1)

```diff
@@ -43,7 +43,7 @@ def log_metrics(self, metrics, grad_norm_dic):

         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
-            self.logger.log_metrics(scalar_metrics, step_idx=self.global_step)
+            self.logger.log_metrics(scalar_metrics, step=self.global_step)
             self.logger.save()

     def add_tqdm_metrics(self, metrics):
```

pytorch_lightning/trainer/trainer.py (+13 −13)

```diff
@@ -72,8 +72,8 @@ def __init__(
                  accumulate_grad_batches=1,
                  max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
                  min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
-                 max_num_epochs=1000,
-                 min_num_epochs=1,
+                 max_epochs=1000,
+                 min_epochs=1,
                  train_percent_check=1.0,
                  val_percent_check=1.0,
                  test_percent_check=1.0,
@@ -111,8 +111,8 @@ def __init__(
         :param int check_val_every_n_epoch: check val every n train epochs
         :param bool fast_dev_run: runs full iteration over everything to find bugs
         :param int accumulate_grad_batches: Accumulates grads every k batches
-        :param int max_num_epochs:
-        :param int min_num_epochs:
+        :param int max_epochs:
+        :param int min_epochs:
         :param int train_percent_check: How much of train set to check
         :param int val_percent_check: How much of val set to check
         :param int test_percent_check: How much of test set to check
@@ -158,17 +158,17 @@ def __init__(
         self.process_position = process_position
         self.weights_summary = weights_summary
         if max_nb_epochs is not None:  # Backward compatibility
-            warnings.warn("`max_nb_epochs` has renamed to `max_num_epochs` since v0.5.0"
+            warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                           " and will be removed in v0.8.0", DeprecationWarning)
-            if not max_num_epochs:  # in case you did not set the proper value
-                max_num_epochs = max_nb_epochs
-        self.max_num_epochs = max_num_epochs
+            if not max_epochs:  # in case you did not set the proper value
+                max_epochs = max_nb_epochs
+        self.max_epochs = max_epochs
         if min_nb_epochs is not None:  # Backward compatibility
-            warnings.warn("`min_nb_epochs` has renamed to `min_num_epochs` since v0.5.0"
+            warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                           " and will be removed in v0.8.0", DeprecationWarning)
-            if not min_num_epochs:  # in case you did not set the proper value
-                min_num_epochs = min_nb_epochs
-        self.min_num_epochs = min_num_epochs
+            if not min_epochs:  # in case you did not set the proper value
+                min_epochs = min_nb_epochs
+        self.min_epochs = min_epochs
         if nb_sanity_val_steps is not None:  # Backward compatibility
             warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
                           " and will be removed in v0.8.0", DeprecationWarning)
@@ -183,7 +183,7 @@ def __init__(
         self.fast_dev_run = fast_dev_run
         if self.fast_dev_run:
             self.num_sanity_val_steps = 1
-            self.max_num_epochs = 1
+            self.max_epochs = 1
             m = '''
             Running in fast_dev_run mode: will run a full train,
             val loop using a single batch
```
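Note that only the older `*_nb_*` spellings keep a compatibility path, and the `if not max_epochs` guard copies the legacy value only when the new argument is falsy. A sketch of the transition behavior (assuming an otherwise default `Trainer` construction works in your environment):

```python
import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # the legacy keyword still works, but its value only takes effect when the
    # new argument is explicitly falsy (the default of 1000 would win otherwise)
    trainer = Trainer(max_nb_epochs=5, max_epochs=None)

assert trainer.max_epochs == 5
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```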

pytorch_lightning/trainer/training_loop.py (+9 −9)

```diff
@@ -23,7 +23,7 @@
 .. code-block:: python

     # DEFAULT
-    trainer = Trainer(min_num_epochs=1, max_num_epochs=1000)
+    trainer = Trainer(min_epochs=1, max_epochs=1000)

 Early stopping
 --------------
@@ -259,17 +259,17 @@ def process_output(self, output, train):

     def train(self):
         # run all epochs
-        for epoch_idx in range(self.current_epoch, self.max_num_epochs):
+        for epoch in range(self.current_epoch, self.max_epochs):
             # set seed for distributed sampler (enables shuffling for each epoch)
             if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
-                self.get_train_dataloader().sampler.set_epoch(epoch_idx)
+                self.get_train_dataloader().sampler.set_epoch(epoch)

             # get model
             model = self.get_model()

             # update training progress in trainer and model
-            model.current_epoch = epoch_idx
-            self.current_epoch = epoch_idx
+            model.current_epoch = epoch
+            self.current_epoch = epoch

             # val can be checked multiple times in epoch
             is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
@@ -294,11 +294,11 @@ def train(self):
             # .reset() doesn't work on disabled progress bar so we should check
             if not self.main_progress_bar.disable:
                 self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch_idx + 1}' if not self.is_iterable_train_dataloader else ''
+            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
             self.main_progress_bar.set_description(desc)

             # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin(epoch_idx, self)
+            self.accumulation_scheduler.on_epoch_begin(epoch, self)

             # -----------------
             # RUN TNG EPOCH
@@ -319,9 +319,9 @@ def train(self):
             self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

             # early stopping
-            met_min_epochs = epoch_idx > self.min_num_epochs
+            met_min_epochs = epoch > self.min_epochs
             if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
-                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_idx,
+                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                     logs=self.callback_metrics)
                 # stop training
                 stop = should_stop and met_min_epochs
```
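Since the loop variable is zero-based, `met_min_epochs = epoch > self.min_epochs` means early stopping can first fire at epoch index `min_epochs + 1`. A sketch of wiring the renamed knobs to early stopping (the callback import path and arguments are assumed from this era's API, not shown in this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=3)
trainer = Trainer(
    min_epochs=2,    # early stopping is ignored until this epoch count is exceeded
    max_epochs=100,  # hard cap on the range() in train() above
    early_stop_callback=early_stop,
)
```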

pytorch_lightning/utilities/arg_parse.py (+4 −2)

```diff
@@ -15,8 +15,10 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
     parser.opt_list('--accumulate_grad_batches', default=1, type=int, tunable=False,
                     help='accumulates gradients k times before applying update.'
                          ' Simulates huge batch size')
-    parser.add_argument('--max_num_epochs', default=200, type=int, help='cap epochs')
-    parser.add_argument('--min_num_epochs', default=2, type=int, help='min epochs')
+    parser.add_argument('--max_epochs', default=200, type=int,
+                        help='maximum number of epochs')
+    parser.add_argument('--min_epochs', default=2, type=int,
+                        help='minimum number of epochs')
     parser.add_argument('--train_percent_check', default=1.0, type=float,
                         help='how much of training set to check')
     parser.add_argument('--val_percent_check', default=1.0, type=float,
```
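The renamed flags feed straight into the trainer. A minimal sketch using plain `argparse` as a stand-in for `add_default_args` (which actually uses test-tube's `HyperOptArgumentParser`):

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser.add_argument('--max_epochs', default=200, type=int,
                    help='maximum number of epochs')
parser.add_argument('--min_epochs', default=2, type=int,
                    help='minimum number of epochs')
args = parser.parse_args(['--max_epochs', '10'])

trainer = Trainer(max_epochs=args.max_epochs, min_epochs=args.min_epochs)
```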

tests/test_amp.py (+6 −6)

```diff
@@ -23,7 +23,7 @@ def test_amp_single_gpu(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         show_progress_bar=True,
-        max_num_epochs=1,
+        max_epochs=1,
         gpus=1,
         distributed_backend='ddp',
         use_amp=True
@@ -45,7 +45,7 @@ def test_no_amp_single_gpu(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         show_progress_bar=True,
-        max_num_epochs=1,
+        max_epochs=1,
         gpus=1,
         distributed_backend='dp',
         use_amp=True
@@ -69,7 +69,7 @@ def test_amp_gpu_ddp(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         show_progress_bar=True,
-        max_num_epochs=1,
+        max_epochs=1,
         gpus=2,
         distributed_backend='ddp',
         use_amp=True
@@ -94,7 +94,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):

     trainer_options = dict(
         show_progress_bar=True,
-        max_num_epochs=1,
+        max_epochs=1,
         gpus=[0],
         distributed_backend='ddp',
         use_amp=True
@@ -153,7 +153,7 @@ def test_cpu_model_with_amp(tmpdir):
         default_save_path=tmpdir,
         show_progress_bar=False,
         logger=tutils.get_test_tube_logger(tmpdir),
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4,
         use_amp=True
@@ -175,7 +175,7 @@ def test_amp_gpu_dp(tmpdir):
     model, hparams = tutils.get_model()
     trainer_options = dict(
         default_save_path=tmpdir,
-        max_num_epochs=1,
+        max_epochs=1,
         gpus='0, 1',  # test init with gpu string
         distributed_backend='dp',
         use_amp=True
```

tests/test_cpu_models.py (+9 −9)

```diff
@@ -46,7 +46,7 @@ def test_lbfgs_cpu_model(tmpdir):

     trainer_options = dict(
         default_save_path=tmpdir,
-        max_num_epochs=1,
+        max_epochs=1,
         print_nan_grads=True,
         show_progress_bar=False,
         weights_summary='top',
@@ -64,7 +64,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):

     trainer_options = dict(
         default_save_path=tmpdir,
-        max_num_epochs=1,
+        max_epochs=1,
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         print_nan_grads=True,
@@ -97,7 +97,7 @@ def test_running_test_after_fitting(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         show_progress_bar=False,
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
         test_percent_check=0.2,
@@ -135,7 +135,7 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):

     trainer_options = dict(
         show_progress_bar=False,
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
         test_percent_check=0.2,
@@ -209,7 +209,7 @@ def test_simple_cpu(tmpdir):
     # logger file to get meta
     trainer_options = dict(
         default_save_path=tmpdir,
-        max_num_epochs=1,
+        max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.1,
     )
@@ -230,7 +230,7 @@ def test_cpu_model(tmpdir):
         default_save_path=tmpdir,
         show_progress_bar=False,
         logger=tutils.get_test_tube_logger(tmpdir),
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4
     )
@@ -253,7 +253,7 @@ def test_all_features_cpu_model(tmpdir):
         show_progress_bar=False,
         logger=tutils.get_test_tube_logger(tmpdir),
         accumulate_grad_batches=2,
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4
     )
@@ -314,7 +314,7 @@ def train_dataloader(self):

     trainer_options = dict(
         default_save_path=tmpdir,
-        max_num_epochs=1,
+        max_epochs=1,
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
@@ -348,7 +348,7 @@ def test_single_gpu_model(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         show_progress_bar=False,
-        max_num_epochs=1,
+        max_epochs=1,
         train_percent_check=0.1,
         val_percent_check=0.1,
         gpus=1
```
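The tests all follow the same pattern: build a `trainer_options` dict, splat it into `Trainer`, and fit; only the keyword changes in this commit. A minimal sketch of that pattern (model construction and assertions elided):

```python
from pytorch_lightning import Trainer

trainer_options = dict(
    max_epochs=1,  # renamed from max_num_epochs
    train_percent_check=0.4,
    val_percent_check=0.4,
)
trainer = Trainer(**trainer_options)
# trainer.fit(model)  # model construction elided
```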
