@@ -772,3 +772,39 @@ def test_trainer_config(trainer_kwargs, expected):
772
772
assert trainer .on_gpu is expected ["on_gpu" ]
773
773
assert trainer .single_gpu is expected ["single_gpu" ]
774
774
assert trainer .num_processes == expected ["num_processes" ]
775
+
776
+
777
+ def test_trainer_subclassing ():
778
+
779
+ model = EvalModelTemplate ()
780
+
781
+ # First way of pulling out args from signature is to list them
782
+ class TrainerSubclass (Trainer ):
783
+
784
+ def __init__ (self , custom_arg , * args , custom_kwarg = 'test' , ** kwargs ):
785
+ super ().__init__ (* args , ** kwargs )
786
+ self .custom_arg = custom_arg
787
+ self .custom_kwarg = custom_kwarg
788
+
789
+ trainer = TrainerSubclass (123 , custom_kwarg = 'custom' , fast_dev_run = True )
790
+ result = trainer .fit (model )
791
+ assert result == 1
792
+ assert trainer .custom_arg == 123
793
+ assert trainer .custom_kwarg == 'custom'
794
+ assert trainer .fast_dev_run
795
+
796
+ # Second way is to pop from the dict
797
+ # It's a special case because Trainer does not have any positional args
798
+ class TrainerSubclass (Trainer ):
799
+
800
+ def __init__ (self , ** kwargs ):
801
+ self .custom_arg = kwargs .pop ('custom_arg' , 0 )
802
+ self .custom_kwarg = kwargs .pop ('custom_kwarg' , 'test' )
803
+ super ().__init__ (** kwargs )
804
+
805
+ trainer = TrainerSubclass (custom_kwarg = 'custom' , fast_dev_run = True )
806
+ result = trainer .fit (model )
807
+ assert result == 1
808
+ assert trainer .custom_kwarg == 'custom'
809
+ assert trainer .fast_dev_run
810
+
0 commit comments