Skip to content

Commit 040c1f2

Browse files
Merge pull request #1636 from PyTorchLightning/callback
test pickling
2 parents 013fd98 + e7ea564 commit 040c1f2

File tree

5 files changed

+34
-5
lines changed

5 files changed

+34
-5
lines changed

pytorch_lightning/callbacks/early_stopping.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
5757
self.min_delta = min_delta
5858
self.wait = 0
5959
self.stopped_epoch = 0
60+
self.mode = mode
6061

6162
mode_dict = {
6263
'min': torch.lt,
@@ -67,9 +68,8 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
6768
if mode not in mode_dict:
6869
if self.verbose > 0:
6970
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
70-
mode = 'auto'
71+
self.mode = 'auto'
7172

72-
self.monitor_op = mode_dict[mode]
7373
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
7474

7575
def _validate_condition_metric(self, logs):
@@ -94,6 +94,15 @@ def _validate_condition_metric(self, logs):
9494

9595
return True
9696

97+
@property
98+
def monitor_op(self):
99+
mode_dict = {
100+
'min': torch.lt,
101+
'max': torch.gt,
102+
'auto': torch.gt if 'acc' in self.monitor else torch.lt
103+
}
104+
return mode_dict[self.mode]
105+
97106
def on_train_start(self, trainer, pl_module):
98107
# Allow instances to be re-used
99108
self.wait = 0

tests/callbacks/test_callbacks.py

+9
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,15 @@ def training_step(self, *args, **kwargs):
240240
assert trainer.current_epoch < trainer.max_epochs
241241

242242

243+
def test_pickling(tmpdir):
244+
import pickle
245+
early_stopping = EarlyStopping()
246+
ckpt = ModelCheckpoint(tmpdir)
247+
248+
pickle.dumps(ckpt)
249+
pickle.dumps(early_stopping)
250+
251+
243252
def test_model_checkpoint_with_non_string_input(tmpdir):
244253
""" Test that None in checkpoint callback is valid and that chkp_path is
245254
set correctly """

tests/loggers/test_all.py

+3
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
8888
logger_args = _get_logger_args(logger_class, tmpdir)
8989
logger = logger_class(**logger_args)
9090

91+
# test pickling loggers
92+
pickle.dumps(logger)
93+
9194
trainer = Trainer(
9295
max_epochs=1,
9396
logger=logger

tests/trainer/test_trainer.py

+7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
)
2727

2828

29+
def test_model_pickle(tmpdir):
30+
import pickle
31+
32+
model = TestModelBase(tutils.get_default_hparams())
33+
pickle.dumps(model)
34+
35+
2936
def test_hparams_save_load(tmpdir):
3037
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10, 'failed_key': lambda x: x})
3138

tests/trainer/test_trainer_cli.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
from argparse import ArgumentParser, Namespace
33
from unittest import mock
4+
import pickle
45

56
import pytest
67

@@ -42,14 +43,14 @@ def test_add_argparse_args_redefined(cli_args):
4243

4344
args = parser.parse_args(cli_args)
4445

46+
# make sure we can pickle args
47+
pickle.dumps(args)
48+
4549
# Check few deprecated args are not in namespace:
4650
for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'):
4751
assert depr_name not in args
4852

4953
trainer = Trainer.from_argparse_args(args=args)
50-
51-
# make sure trainer can be pickled
52-
import pickle
5354
pickle.dumps(trainer)
5455

5556
assert isinstance(trainer, Trainer)

0 commit comments

Comments (0)