@@ -772,3 +772,41 @@ def test_trainer_config(trainer_kwargs, expected):
772
772
assert trainer .on_gpu is expected ["on_gpu" ]
773
773
assert trainer .single_gpu is expected ["single_gpu" ]
774
774
assert trainer .num_processes == expected ["num_processes" ]
775
+
776
+
777
+ def test_trainer_subclassing ():
778
+ model = EvalModelTemplate ()
779
+
780
+ # First way of pulling out args from signature is to list them
781
+ class TrainerSubclass (Trainer ):
782
+
783
+ def __init__ (self , custom_arg , * args , custom_kwarg = 'test' , ** kwargs ):
784
+ super ().__init__ (* args , ** kwargs )
785
+ self .custom_arg = custom_arg
786
+ self .custom_kwarg = custom_kwarg
787
+
788
+ trainer = TrainerSubclass (123 , custom_kwarg = 'custom' , fast_dev_run = True )
789
+ result = trainer .fit (model )
790
+ assert result == 1
791
+ assert trainer .custom_arg == 123
792
+ assert trainer .custom_kwarg == 'custom'
793
+ assert trainer .fast_dev_run
794
+
795
+ # Second way is to pop from the dict
796
+ # It's a special case because Trainer does not have any positional args
797
+ class TrainerSubclass (Trainer ):
798
+
799
+ def __init__ (self , ** kwargs ):
800
+ self .custom_arg = kwargs .pop ('custom_arg' , 0 )
801
+ self .custom_kwarg = kwargs .pop ('custom_kwarg' , 'test' )
802
+ super ().__init__ (** kwargs )
803
+
804
+ trainer = TrainerSubclass (custom_kwarg = 'custom' , fast_dev_run = True )
805
+ result = trainer .fit (model )
806
+ assert result == 1
807
+ assert trainer .custom_kwarg == 'custom'
808
+ assert trainer .fast_dev_run
809
+
810
+ # when we pass in an unknown arg, the base class should complain
811
+ with pytest .raises (TypeError , match = r"__init__\(\) got an unexpected keyword argument 'abcdefg'" ) as e :
812
+ TrainerSubclass (abcdefg = 'unknown_arg' )
0 commit comments