
Commit 769a459
remove extra kwargs from Trainer init (#1820)

* remove kwargs
* remove useless test
* rename unknown trainer flag
* trainer inheritance and test
* blank line
* test for unknown arg
* changelog
1 parent 692f302 commit 769a459
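The user-visible effect of dropping `**kwargs` is that a misspelled Trainer argument now fails loudly instead of being silently ignored. A minimal sketch of that behavior, assuming pytorch-lightning at or after this commit; `max_epoch` below is a deliberate misspelling of the real `max_epochs` flag:

from pytorch_lightning import Trainer

try:
    # before this commit the stray keyword was swallowed by **kwargs and silently ignored;
    # now it never matches the explicit signature and Python raises at call time
    Trainer(max_epoch=1)   # misspelled on purpose; the real flag is `max_epochs`
except TypeError as err:
    print(err)             # __init__() got an unexpected keyword argument 'max_epoch'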

File tree

5 files changed: +46 −26 lines changed


CHANGELOG.md

+2
@@ -69,6 +69,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `hparam` logging with metrics ([#1647](https://github.com/PyTorchLightning/pytorch-lightning/pull/1647))
 
 
+- Fixed an issue with Trainer constructor silently ignoring unknown/misspelled arguments ([#1820](https://github.com/PyTorchLightning/pytorch-lightning/pull/1820))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed

pytorch_lightning/trainer/callback_hook.py

+4 −5

@@ -6,11 +6,10 @@
 
 class TrainerCallbackHookMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.callbacks: List[Callback] = []
-        self.get_model: Callable = ...
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    callbacks: List[Callback] = []
+    get_model: Callable = ...
 
     def on_init_start(self):
        """Called when the trainer initialization begins, model has not yet been set."""
pytorch_lightning/trainer/trainer.py

+1 −1

@@ -142,7 +142,6 @@ def __init__(
             use_amp=None,  # backward compatible, todo: remove in v0.9.0
             show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
             nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
-            **kwargs
     ):
         r"""
@@ -305,6 +304,7 @@ def __init__(
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.
        """
+        super().__init__()
 
         self.deterministic = deterministic
         torch.backends.cudnn.deterministic = self.deterministic
tests/trainer/test_dataloaders.py

+1 −20

@@ -289,7 +289,7 @@ def test_inf_train_dataloader(tmpdir, check_interval):
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        train_check_interval=check_interval,
+        val_check_interval=check_interval
     )
     result = trainer.fit(model)
     # verify training completed
@@ -315,25 +315,6 @@ def test_inf_val_dataloader(tmpdir, check_interval):
     assert result == 1
 
 
-@pytest.mark.parametrize('check_interval', [50, 1.0])
-def test_inf_test_dataloader(tmpdir, check_interval):
-    """Test inf test data loader (e.g. IterableDataset)"""
-
-    model = EvalModelTemplate()
-    model.test_dataloader = model.test_dataloader__infinite
-
-    # logger file to get meta
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        test_check_interval=check_interval,
-    )
-    result = trainer.fit(model)
-
-    # verify training completed
-    assert result == 1
-
-
 def test_error_on_zero_len_dataloader(tmpdir):
     """ Test that error is raised if a zero-length dataloader is defined """
tests/trainer/test_trainer.py

+38

@@ -772,3 +772,41 @@ def test_trainer_config(trainer_kwargs, expected):
     assert trainer.on_gpu is expected["on_gpu"]
     assert trainer.single_gpu is expected["single_gpu"]
     assert trainer.num_processes == expected["num_processes"]
+
+
+def test_trainer_subclassing():
+    model = EvalModelTemplate()
+
+    # First way of pulling out args from signature is to list them
+    class TrainerSubclass(Trainer):
+
+        def __init__(self, custom_arg, *args, custom_kwarg='test', **kwargs):
+            super().__init__(*args, **kwargs)
+            self.custom_arg = custom_arg
+            self.custom_kwarg = custom_kwarg
+
+    trainer = TrainerSubclass(123, custom_kwarg='custom', fast_dev_run=True)
+    result = trainer.fit(model)
+    assert result == 1
+    assert trainer.custom_arg == 123
+    assert trainer.custom_kwarg == 'custom'
+    assert trainer.fast_dev_run
+
+    # Second way is to pop from the dict
+    # It's a special case because Trainer does not have any positional args
+    class TrainerSubclass(Trainer):
+
+        def __init__(self, **kwargs):
+            self.custom_arg = kwargs.pop('custom_arg', 0)
+            self.custom_kwarg = kwargs.pop('custom_kwarg', 'test')
+            super().__init__(**kwargs)
+
+    trainer = TrainerSubclass(custom_kwarg='custom', fast_dev_run=True)
+    result = trainer.fit(model)
+    assert result == 1
+    assert trainer.custom_kwarg == 'custom'
+    assert trainer.fast_dev_run
+
+    # when we pass in an unknown arg, the base class should complain
+    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'") as e:
+        TrainerSubclass(abcdefg='unknown_arg')