@@ -39,6 +39,10 @@ def __init__(self):
39
39
super().__init__()
40
40
def prepare_data(self):
41
41
# download, split, etc...
42
+ # only called on rank 0
43
+ def setup(self):
44
+ # make assignments here
45
+ # called on every process in DDP
42
46
def train_dataloader(self):
43
47
train_split = Dataset(...)
44
48
return DataLoader(train_split)
@@ -72,6 +76,9 @@ def __init__(
72
76
73
77
@property
74
78
def train_transforms (self ):
79
+ """
80
+ Optional transforms you can apply to the train dataset
81
+ """
75
82
return self ._train_transforms
76
83
77
84
@train_transforms .setter
@@ -80,6 +87,9 @@ def train_transforms(self, t):
80
87
81
88
@property
82
89
def val_transforms (self ):
90
+ """
91
+ Optional transforms you can apply to the validation dataset
92
+ """
83
93
return self ._val_transforms
84
94
85
95
@val_transforms .setter
@@ -88,6 +98,9 @@ def val_transforms(self, t):
88
98
89
99
@property
90
100
def test_transforms (self ):
101
+ """
102
+ Optional transforms you can apply to the test dataset
103
+ """
91
104
return self ._test_transforms
92
105
93
106
@test_transforms .setter
@@ -96,9 +109,9 @@ def test_transforms(self, t):
96
109
97
110
def size (self , dim = None ) -> Union [Tuple , int ]:
98
111
"""
99
- Return the dimension of each input
100
- Either as a tuple or list of tuples
112
+ Return the dimension of each input either as a tuple or list of tuples.
101
113
"""
114
+
102
115
if dim is not None :
103
116
return self .dims [dim ]
104
117
@@ -109,20 +122,29 @@ def prepare_data(self, *args, **kwargs):
109
122
"""
110
123
Use this to download and prepare data.
111
124
In distributed training (GPU, TPU), this will only be called once.
112
- This is called before requesting the dataloaders:
113
- .. warning:: Do not assign anything to the model in this step since this will only be called on 1 GPU.
125
+ .. warning:: Do not assign any state to the datamodule in this step: it is only called on 1 GPU, so the assignments would be missing on every other process.
114
126
Pseudocode::
115
- model.prepare_data()
116
- model.train_dataloader()
117
- model.val_dataloader()
118
- model.test_dataloader()
127
+ dm.prepare_data()
128
+ dm.setup()
119
129
Example::
120
130
def prepare_data(self):
121
131
download_imagenet()
122
132
clean_imagenet()
123
133
cache_imagenet()
124
134
"""
125
135
136
+ @abstractmethod
137
+ def setup (self , * args , ** kwargs ):
138
+ """
139
+ Use this to load your data from file, split it, etc. It is safe to make state assignments here.
140
+ This hook is called on every process when using DDP.
141
+
142
+ Example::
143
+ def setup(self):
144
+ data = load_data(...)
145
+ self.train_ds, self.val_ds, self.test_ds = split_data(data)
146
+ """
147
+
126
148
@abstractmethod
127
149
def train_dataloader (self , * args , ** kwargs ) -> DataLoader :
128
150
"""
0 commit comments