Skip to content

Commit 081b2a8

Browse files
committed
🚧 .
1 parent ea314f4 commit 081b2a8

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

pytorch_lightning/core/datamodule.py

+6
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ def prepare_data(self):
101101
cache_imagenet()
102102
"""
103103

104+
@abstractmethod
105+
def setup(self, stage):
106+
"""
107+
Use this to make assignments to the class.
108+
"""
109+
104110
@abstractmethod
105111
def train_dataloader(self, *args, **kwargs) -> DataLoader:
106112
"""

pytorch_lightning/trainer/model_hooks.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
from abc import ABC, abstractmethod
33

4+
from pytorch_lightning.core.datamodule import LightningDataModule
45
from pytorch_lightning.core.lightning import LightningModule
56

67

@@ -15,7 +16,9 @@ def is_function_implemented(self, f_name, model=None):
1516
def is_overridden(self, method_name: str, model: LightningModule = None) -> bool:
1617
if model is None:
1718
model = self.get_model()
18-
super_object = LightningModule
19+
# if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
20+
# TODO - refactor this function to accept model_name, instance, parent so it makes more sense
21+
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
1922

2023
# assert model, 'no model passes'
2124

pytorch_lightning/trainer/trainer.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torch.utils.data import DataLoader
1010

1111
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
12+
from pytorch_lightning.core.datamodule import LightningDataModule
1213
from pytorch_lightning.core.lightning import LightningModule
1314
from pytorch_lightning.core.memory import ModelSummary
1415
from pytorch_lightning.loggers import LightningLoggerBase
@@ -890,7 +891,8 @@ def fit(
890891
self,
891892
model: LightningModule,
892893
train_dataloader: Optional[DataLoader] = None,
893-
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
894+
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
895+
datamodule: Optional[LightningDataModule] = None
894896
):
895897
r"""
896898
Runs the full optimization routine.
@@ -939,6 +941,7 @@ def fit(
939941

940942
# set up the passed in dataloaders (if needed)
941943
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
944+
self.__attach_datamodule(model, datamodule)
942945

943946
# check that model is configured correctly
944947
self.check_model_configuration(model)
@@ -1111,6 +1114,24 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non
11111114
if test_dataloaders is not None:
11121115
model.test_dataloader = _PatchDataLoader(test_dataloaders)
11131116

1117+
def __attach_datamodule(self, model, datamodule=None):
1118+
1119+
# We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
1120+
datamodule = datamodule or getattr(model, 'datamodule', None)
1121+
1122+
# If we have a datamodule, attach necessary hooks + dataloaders
1123+
if datamodule:
1124+
if self.is_overridden('setup', datamodule):
1125+
model.setup = datamodule.setup
1126+
if self.is_overridden('prepare_data', datamodule):
1127+
model.prepare_data = datamodule.prepare_data
1128+
if self.is_overridden('train_dataloader', datamodule):
1129+
model.train_dataloader = datamodule.train_dataloader
1130+
if self.is_overridden('val_dataloader', datamodule):
1131+
model.val_dataloader = datamodule.val_dataloader
1132+
if self.is_overridden('test_dataloader', datamodule):
1133+
model.test_dataloader = datamodule.test_dataloader
1134+
11141135
def run_pretrain_routine(self, model: LightningModule):
11151136
"""Sanity check a few things before starting actual training.
11161137
@@ -1241,7 +1262,8 @@ def test(
12411262
model: Optional[LightningModule] = None,
12421263
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
12431264
ckpt_path: Optional[str] = 'best',
1244-
verbose: bool = True
1265+
verbose: bool = True,
1266+
datamodule: Optional[LightningDataModule] = None
12451267
):
12461268
r"""
12471269
@@ -1305,6 +1327,9 @@ def test(
13051327
if self.global_rank != 0:
13061328
return
13071329

1330+
# Attach datamodule to get setup/prepare_data added to model before the call to it below
1331+
self.__attach_datamodule(model or self.get_model(), datamodule)
1332+
13081333
self.setup('test')
13091334

13101335
if model is not None:

0 commit comments

Comments
 (0)