|
9 | 9 | from torch.utils.data import DataLoader
|
10 | 10 |
|
11 | 11 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
|
| 12 | +from pytorch_lightning.core.datamodule import LightningDataModule |
12 | 13 | from pytorch_lightning.core.lightning import LightningModule
|
13 | 14 | from pytorch_lightning.core.memory import ModelSummary
|
14 | 15 | from pytorch_lightning.loggers import LightningLoggerBase
|
@@ -890,7 +891,8 @@ def fit(
|
890 | 891 | self,
|
891 | 892 | model: LightningModule,
|
892 | 893 | train_dataloader: Optional[DataLoader] = None,
|
893 |
| - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None |
| 894 | + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, |
| 895 | + datamodule: Optional[LightningDataModule] = None |
894 | 896 | ):
|
895 | 897 | r"""
|
896 | 898 | Runs the full optimization routine.
|
@@ -939,6 +941,7 @@ def fit(
|
939 | 941 |
|
940 | 942 | # set up the passed in dataloaders (if needed)
|
941 | 943 | self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
|
| 944 | + self.__attach_datamodule(model, datamodule) |
942 | 945 |
|
943 | 946 | # check that model is configured correctly
|
944 | 947 | self.check_model_configuration(model)
|
@@ -1111,6 +1114,24 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non
|
1111 | 1114 | if test_dataloaders is not None:
|
1112 | 1115 | model.test_dataloader = _PatchDataLoader(test_dataloaders)
|
1113 | 1116 |
|
| 1117 | + def __attach_datamodule(self, model, datamodule=None): |
| 1118 | + |
| 1119 | + # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it |
| 1120 | + datamodule = datamodule or getattr(model, 'datamodule', None) |
| 1121 | + |
| 1122 | + # If we have a datamodule, attach necessary hooks + dataloaders |
| 1123 | + if datamodule: |
| 1124 | + if self.is_overridden('setup', datamodule): |
| 1125 | + model.setup = datamodule.setup |
| 1126 | + if self.is_overridden('prepare_data', datamodule): |
| 1127 | + model.prepare_data = datamodule.prepare_data |
| 1128 | + if self.is_overridden('train_dataloader', datamodule): |
| 1129 | + model.train_dataloader = datamodule.train_dataloader |
| 1130 | + if self.is_overridden('val_dataloader', datamodule): |
| 1131 | + model.val_dataloader = datamodule.val_dataloader |
| 1132 | + if self.is_overridden('test_dataloader', datamodule): |
| 1133 | + model.test_dataloader = datamodule.test_dataloader |
| 1134 | + |
1114 | 1135 | def run_pretrain_routine(self, model: LightningModule):
|
1115 | 1136 | """Sanity check a few things before starting actual training.
|
1116 | 1137 |
|
@@ -1241,7 +1262,8 @@ def test(
|
1241 | 1262 | model: Optional[LightningModule] = None,
|
1242 | 1263 | test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
|
1243 | 1264 | ckpt_path: Optional[str] = 'best',
|
1244 |
| - verbose: bool = True |
| 1265 | + verbose: bool = True, |
| 1266 | + datamodule: Optional[LightningDataModule] = None |
1245 | 1267 | ):
|
1246 | 1268 | r"""
|
1247 | 1269 |
|
@@ -1305,6 +1327,9 @@ def test(
|
1305 | 1327 | if self.global_rank != 0:
|
1306 | 1328 | return
|
1307 | 1329 |
|
| 1330 | + # Attach datamodule to get setup/prepare_data added to model before the call to it below |
| 1331 | + self.__attach_datamodule(model or self.get_model(), datamodule) |
| 1332 | + |
1308 | 1333 | self.setup('test')
|
1309 | 1334 |
|
1310 | 1335 | if model is not None:
|
|
0 commit comments