Skip to content

Commit 329c264

Browse files
committed
trainer inheritance and test
1 parent 2abf076 commit 329c264

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

pytorch_lightning/trainer/callback_hook.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66

77
class TrainerCallbackHookMixin(ABC):
88

9-
def __init__(self):
10-
# this is just a summary on variables used in this abstract class,
11-
# the proper values/initialisation should be done in child class
12-
self.callbacks: List[Callback] = []
13-
self.get_model: Callable = ...
9+
# this is just a summary on variables used in this abstract class,
10+
# the proper values/initialisation should be done in child class
11+
callbacks: List[Callback] = []
12+
get_model: Callable = ...
1413

1514
def on_init_start(self):
1615
"""Called when the trainer initialization begins, model has not yet been set."""

pytorch_lightning/trainer/trainer.py

+1
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def __init__(
304304
Additionally, can be set to either `power` that estimates the batch size through
305305
a power search or `binsearch` that estimates the batch size through a binary search.
306306
"""
307+
super().__init__()
307308

308309
self.deterministic = deterministic
309310
torch.backends.cudnn.deterministic = self.deterministic

tests/trainer/test_trainer.py

+36
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,39 @@ def test_trainer_config(trainer_kwargs, expected):
772772
assert trainer.on_gpu is expected["on_gpu"]
773773
assert trainer.single_gpu is expected["single_gpu"]
774774
assert trainer.num_processes == expected["num_processes"]
775+
776+
777+
def test_trainer_subclassing():
    """Trainer can be subclassed with extra constructor arguments.

    Exercises the two idiomatic ways a subclass can accept custom
    arguments while still forwarding the rest to ``Trainer.__init__``.
    """
    model = EvalModelTemplate()

    # Way 1: name the custom parameters explicitly in the signature and
    # forward everything else via *args/**kwargs.
    class SubclassedTrainer(Trainer):

        def __init__(self, custom_arg, *args, custom_kwarg='test', **kwargs):
            super().__init__(*args, **kwargs)
            self.custom_arg = custom_arg
            self.custom_kwarg = custom_kwarg

    trainer = SubclassedTrainer(123, custom_kwarg='custom', fast_dev_run=True)
    result = trainer.fit(model)
    assert result == 1
    assert trainer.custom_arg == 123
    assert trainer.custom_kwarg == 'custom'
    assert trainer.fast_dev_run

    # Way 2: pop the custom keys out of **kwargs before delegating.
    # This works here because Trainer takes no positional arguments.
    class PoppingTrainer(Trainer):

        def __init__(self, **kwargs):
            self.custom_arg = kwargs.pop('custom_arg', 0)
            self.custom_kwarg = kwargs.pop('custom_kwarg', 'test')
            super().__init__(**kwargs)

    trainer = PoppingTrainer(custom_kwarg='custom', fast_dev_run=True)
    result = trainer.fit(model)
    assert result == 1
    assert trainer.custom_kwarg == 'custom'
    assert trainer.fast_dev_run
810+

0 commit comments

Comments
 (0)