Skip to content

Commit 4c4b734

Browse files
committed
🐛 fix add_argparse_args
1 parent 0f10c54 commit 4c4b734

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

pytorch_lightning/core/datamodule.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88

99

1010
class _DataModuleWrapper(type):
11+
1112
def __call__(cls, *args, **kwargs):
1213
"""A wrapper for LightningDataModule that:
1314
1415
1. Runs user defined subclass's __init__
1516
2. Assures prepare_data() runs on rank 0
16-
3. Runs prepare_data()
17-
4. Runs setup()
1817
"""
1918

2019
# Get instance of LightningDataModule by mocking its __init__ via __call__
@@ -23,9 +22,6 @@ def __call__(cls, *args, **kwargs):
2322
# Wrap instance's prepare_data function with rank_zero_only and reassign to instance
2423
obj.prepare_data = rank_zero_only(obj.prepare_data)
2524

26-
# Run both prepare_data() and setup() post-init
27-
obj.prepare_data()
28-
obj.setup()
2925
return obj
3026

3127

@@ -278,7 +274,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
278274
List with tuples of 3 values:
279275
(argument name, set with argument types, argument default value).
280276
"""
281-
datamodule_default_params = inspect.signature(cls).parameters
277+
datamodule_default_params = inspect.signature(cls.__init__).parameters
282278
name_type_default = []
283279
for arg in datamodule_default_params:
284280
arg_type = datamodule_default_params[arg].annotation

0 commit comments

Comments
 (0)