Skip to content

Commit caa3016

Browse files
committed
🚧 .
1 parent e912951 commit caa3016

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

pytorch_lightning/core/datamodule.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def __init__(self):
3939
super().__init__()
4040
def prepare_data(self):
4141
# download, split, etc...
42+
# only called on rank 0
43+
def setup(self):
44+
# make assignments here
45+
# called on every process in DDP
4246
def train_dataloader(self):
4347
train_split = Dataset(...)
4448
return DataLoader(train_split)
@@ -72,6 +76,9 @@ def __init__(
7276

7377
@property
7478
def train_transforms(self):
79+
"""
80+
Optional transforms you can apply to train dataset
81+
"""
7582
return self._train_transforms
7683

7784
@train_transforms.setter
@@ -80,6 +87,9 @@ def train_transforms(self, t):
8087

8188
@property
8289
def val_transforms(self):
90+
"""
91+
Optional transforms you can apply to validation dataset
92+
"""
8393
return self._val_transforms
8494

8595
@val_transforms.setter
@@ -88,6 +98,9 @@ def val_transforms(self, t):
8898

8999
@property
90100
def test_transforms(self):
101+
"""
102+
Optional transforms you can apply to test dataset
103+
"""
91104
return self._test_transforms
92105

93106
@test_transforms.setter
@@ -96,9 +109,9 @@ def test_transforms(self, t):
96109

97110
def size(self, dim=None) -> Union[Tuple, int]:
98111
"""
99-
Return the dimension of each input
100-
Either as a tuple or list of tuples
112+
Return the dimension of each input either as a tuple or list of tuples.
101113
"""
114+
102115
if dim is not None:
103116
return self.dims[dim]
104117

@@ -109,20 +122,29 @@ def prepare_data(self, *args, **kwargs):
109122
"""
110123
Use this to download and prepare data.
111124
In distributed (GPU, TPU), this will only be called once.
112-
This is called before requesting the dataloaders:
113-
.. warning:: Do not assign anything to the model in this step since this will only be called on 1 GPU.
125+
.. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.
114126
Pseudocode::
115-
model.prepare_data()
116-
model.train_dataloader()
117-
model.val_dataloader()
118-
model.test_dataloader()
127+
dm.prepare_data()
128+
dm.setup()
119129
Example::
120130
def prepare_data(self):
121131
download_imagenet()
122132
clean_imagenet()
123133
cache_imagenet()
124134
"""
125135

136+
@abstractmethod
137+
def setup(self, *args, **kwargs):
138+
"""
139+
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
140+
This hook is called on every process when using DDP.
141+
142+
Example::
143+
def setup(self):
144+
data = load_data(...)
145+
self.train_ds, self.val_ds, self.test_ds = split_data(data)
146+
"""
147+
126148
@abstractmethod
127149
def train_dataloader(self, *args, **kwargs) -> DataLoader:
128150
"""

pytorch_lightning/trainer/trainer.py

+12
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,12 @@ def fit(
939939
if hasattr(model, 'hparams'):
940940
parsing.clean_namespace(model.hparams)
941941

942+
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
943+
if (train_dataloader or val_dataloaders) and datamodule:
944+
raise MisconfigurationException(
945+
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
946+
)
947+
942948
# set up the passed in dataloaders (if needed)
943949
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
944950
self.__attach_datamodule(model, datamodule)
@@ -1323,6 +1329,12 @@ def test(
13231329
if self.global_rank != 0:
13241330
return
13251331

1332+
# If you supply a datamodule you can't supply test_dataloaders
1333+
if test_dataloaders and datamodule:
1334+
raise MisconfigurationException(
1335+
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
1336+
)
1337+
13261338
# Attach datamodule to get setup/prepare_data added to model before the call to it below
13271339
self.__attach_datamodule(model or self.get_model(), datamodule)
13281340

0 commit comments

Comments
 (0)