
Commit edd4a87

Refactor callbacks (#776)
* Refactor callbacks
* flake8
* Update docstrings
* Simplified callback, protected trainer
* .set_trainer() check
* update docs
* missed super().__init__()
* Updated tests
* Use uppercase
* refine checkpoint callback tests
* Added test_begin() and test_end()
1 parent 27bba1a commit edd4a87

File tree

6 files changed: +120 -75 lines changed

  pytorch_lightning/callbacks/pt_callbacks.py
  pytorch_lightning/trainer/callback_config.py
  pytorch_lightning/trainer/evaluation_loop.py
  pytorch_lightning/trainer/training_loop.py
  pytorch_lightning/trainer/training_tricks.py
  tests/test_trainer.py

pytorch_lightning/callbacks/pt_callbacks.py  (+55, -52)

@@ -12,60 +12,62 @@
 
 import numpy as np
 
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-
 
 class Callback(object):
-    r"""Abstract base class used to build new callbacks.
-    """
+    """Abstract base class used to build new callbacks."""
 
     def __init__(self):
-        self.validation_data = None
-        self.model = None
-
-    def set_params(self, params):
-        self.params = params
+        self._trainer = None
 
-    def set_model(self, model):
-        if isinstance(model, LightningDistributedDataParallel):
-            model = model.module
-        self.model = model
+    def set_trainer(self, trainer):
+        """Make a link to the trainer, so different things like `trainer.current_epoch`,
+        `trainer.batch_idx`, `trainer.global_step` can be used."""
+        self._trainer = trainer
 
-    def on_epoch_begin(self, epoch, logs=None):
-        """
-        called when the epoch begins
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
+        pass
 
-        Args:
-            epoch (int): current epoch
-            logs (dict): key-value pairs of quantities to monitor
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        pass
 
-        Example:
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
+        pass
 
-            on_epoch_begin(epoch=2, logs={'val_loss': 0.2})
-        """
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        pass
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_train_begin(self):
+        """Called when the train begins."""
         pass
 
-    def on_batch_begin(self, batch, logs=None):
-        """
-        called when the batch starts.
+    def on_train_end(self):
+        """Called when the train ends."""
+        pass
 
-        Args:
-            batch (Tensor): current batch tensor
-            logs (dict): key-value pairs of quantities to monitor
-        """
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        pass
 
-    def on_batch_end(self, batch, logs=None):
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
         pass
 
-    def on_train_begin(self, logs=None):
+    def on_test_begin(self):
+        """Called when the test begins."""
         pass
 
-    def on_train_end(self, logs=None):
+    def on_test_end(self):
+        """Called when the test ends."""
         pass
 
 
+_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
+
+
 class EarlyStopping(Callback):
     r"""
     Stop training when a monitored quantity has stopped improving.
@@ -148,13 +150,16 @@ def check_metrics(self, logs):
 
         return True
 
-    def on_train_begin(self, logs=None):
+    def on_train_begin(self):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
         stop_training = False
         if not self.check_metrics(logs):
             return stop_training
@@ -166,13 +171,13 @@ def on_epoch_end(self, epoch, logs=None):
         else:
             self.wait += 1
             if self.wait >= self.patience:
-                self.stopped_epoch = epoch
+                self.stopped_epoch = self._trainer.current_epoch
                 stop_training = True
                 self.on_train_end()
 
         return stop_training
 
-    def on_train_end(self, logs=None):
+    def on_train_end(self):
         if self.stopped_epoch > 0 and self.verbose > 0:
             warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
@@ -306,8 +311,11 @@ def check_monitor_top_k(self, current):
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
-    def on_epoch_end(self, epoch, logs=None):
-        logs = logs or {}
+    def on_validation_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
+        epoch = self._trainer.current_epoch
         self.epochs_since_last_check += 1
 
         if self.save_top_k == 0:
@@ -389,6 +397,8 @@ class GradientAccumulationScheduler(Callback):
     """
 
     def __init__(self, scheduling: dict):
+        super().__init__()
+
         if scheduling == {}:  # empty dict error
             raise TypeError("Empty dict cannot be interpreted correct")
 
@@ -408,21 +418,14 @@ def __init__(self, scheduling: dict):
         self.scheduling = scheduling
         self.epochs = sorted(scheduling.keys())
 
-    def on_epoch_begin(self, epoch, trainer):
+    def on_epoch_begin(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        trainer = self._trainer
         # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, `epoch += 1` should be removed.
-        epoch += 1
+        # In v0.8.0, ` + 1` should be removed.
+        epoch = trainer.current_epoch + 1
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                 break
-
-
-# if __name__ == '__main__':
-#     c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
-#     losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
-#     for i, loss in enumerate(losses):
-#         should_stop = c.on_epoch_end(i, logs={'val_loss': loss})
-#         log.info(loss)
-#         if should_stop:
-#             break
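
With this refactor a callback no longer receives `epoch` or `logs` arguments; once the trainer has called `.set_trainer()`, every hook reads state such as `trainer.current_epoch` and `trainer.callback_metrics` from the linked trainer. A minimal sketch of a user-defined callback written against the new interface follows; the `ValLossHistory` class and the manual wiring in the trailing comments are illustrative, not part of this commit.

from pytorch_lightning.callbacks.pt_callbacks import Callback


class ValLossHistory(Callback):
    """Illustrative callback: records `val_loss` per epoch through the linked trainer."""

    def __init__(self):
        super().__init__()  # initializes self._trainer = None
        self.history = []

    def on_epoch_end(self):
        # state comes from the trainer, not from hook arguments
        assert self._trainer is not None, '.set_trainer() should be called after the callback initialization'
        metrics = self._trainer.callback_metrics
        if 'val_loss' in metrics:
            self.history.append((self._trainer.current_epoch, metrics['val_loss']))


# the trainer links itself before any hook fires, e.g.:
#     callback = ValLossHistory()
#     callback.set_trainer(trainer)  # `trainer` being an already constructed Trainer
#     callback.on_epoch_end()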

pytorch_lightning/trainer/callback_config.py  (+6, -0)

@@ -48,6 +48,9 @@ def configure_checkpoint_callback(self):
             # if checkpoint callback used, then override the weights path
             self.weights_save_path = self.checkpoint_callback.filepath
 
+            # link to the trainer
+            self.checkpoint_callback.set_trainer(self)
+
         # if weights_save_path is still none here, set to current working dir
         if self.weights_save_path is None:
             self.weights_save_path = self.default_save_path
@@ -77,3 +80,6 @@ def configure_early_stopping(self, early_stop_callback):
         else:
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
+
+        if self.early_stop_callback is not None:
+            self.early_stop_callback.set_trainer(self)
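
The linking added here is what satisfies the `set_trainer()` assertions introduced in pt_callbacks.py: the trainer wires itself into the checkpoint and early-stopping callbacks right after configuring them. A rough sketch of the equivalent manual wiring for user-supplied callbacks (the argument values and the pre-existing `trainer` object are assumptions for illustration):

from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint

early_stop_callback = EarlyStopping(monitor='val_loss', patience=3, verbose=True)
checkpoint_callback = ModelCheckpoint(filepath='/tmp/checkpoints', save_top_k=1, verbose=True)

# without these calls, on_epoch_end() / on_validation_end() would trip the
# ".set_trainer() should be called after the callback initialization" assertion
early_stop_callback.set_trainer(trainer)    # `trainer` is assumed to exist already
checkpoint_callback.set_trainer(trainer)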

pytorch_lightning/trainer/evaluation_loop.py  (+1, -2)

@@ -330,8 +330,7 @@ def run_evaluation(self, test=False):
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
-            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
-                                                  logs=self.callback_metrics)
+            self.checkpoint_callback.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
         # make dataloader_idx arg in validation_step optional

pytorch_lightning/trainer/training_loop.py  (+2, -3)

@@ -328,7 +328,7 @@ def train(self):
             self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin(epoch, self)
+            self.accumulation_scheduler.on_epoch_begin()
 
             # -----------------
             # RUN TNG EPOCH
@@ -352,8 +352,7 @@ def train(self):
             met_min_epochs = epoch >= self.min_epochs - 1
             if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
                     (met_min_epochs or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
-                                                                    logs=self.callback_metrics)
+                should_stop = self.early_stop_callback.on_epoch_end()
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:

pytorch_lightning/trainer/training_tricks.py  (+2, -0)

@@ -39,3 +39,5 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
+
+        self.accumulation_scheduler.set_trainer(self)
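
For context on what the scheduler callback does with the trainer it is now linked to: `scheduling` maps an epoch threshold to the `accumulate_grad_batches` value that applies from that epoch on. A small standalone sketch of the lookup performed by the refactored `on_epoch_begin()` (the dict values are made up for illustration):

# {epoch_threshold: accumulate_grad_batches}, thresholds indexed from 1 until v0.6.x
scheduling = {1: 1, 5: 4, 10: 8}
epochs = sorted(scheduling.keys())


def accumulation_for(current_epoch):
    # mirrors GradientAccumulationScheduler.on_epoch_begin(): trainer.current_epoch + 1
    epoch = current_epoch + 1
    for threshold in reversed(epochs):
        if epoch >= threshold:
            return scheduling[threshold]
    return 1  # before the first threshold nothing accumulates


assert accumulation_for(0) == 1    # epoch 1 -> threshold 1
assert accumulation_for(6) == 4    # epoch 7 -> last threshold reached is 5
assert accumulation_for(12) == 8   # epoch 13 -> threshold 10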

tests/test_trainer.py  (+54, -18)

@@ -229,10 +229,16 @@ def mock_save_function(filepath):
 
     # -----------------
     # CASE K=-1 (all)
-    w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))
 
@@ -247,10 +253,16 @@ def mock_save_function(filepath):
 
     # -----------------
     # CASE K=0 (none)
-    w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = os.listdir(save_dir)
 
@@ -261,10 +273,16 @@ def mock_save_function(filepath):
 
     # -----------------
     # CASE K=1 (2.5, epoch 4)
-    w = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))
 
@@ -278,11 +296,17 @@ def mock_save_function(filepath):
     # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
     # make sure other files don't get deleted
 
-    w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
     open(f'{save_dir}/other_file.ckpt', 'a').close()
-    w.save_function = mock_save_function
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))
 
@@ -298,10 +322,16 @@ def mock_save_function(filepath):
     # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch
 
-    w = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for loss in losses:
-        w.on_epoch_end(0, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = 0
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))
 
@@ -314,10 +344,16 @@ def mock_save_function(filepath):
     # CASE K=3 (save the 2nd, 3rd, 4th model)
     # multiple checkpoints within same epoch
 
-    w = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
     for loss in losses:
-        w.on_epoch_end(0, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = 0
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))
 