File tree: 1 file changed (+9 −6 lines)
pytorch_lightning/trainer

@@ -161,6 +161,13 @@ def scale_batch_size(self,
             algorithm is terminated

         batch_arg_name: name of the attribute that stores the batch size.
+            It is expected that the user has provided a model or datamodule that has a hyperparameter
+            with that name. We will look for this attribute name in the following places
+
+            - `model`
+            - `model.hparams`
+            - `model.datamodule`
+            - `trainer.datamodule` (the datamodule passed to the tune method)

         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
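The lookup order added in the docstring above can be illustrated with a small helper. This is a hypothetical sketch of that search order, not Lightning's actual implementation; the function name `lookup_batch_arg` and the assumption that `hparams` supports plain attribute access are ours:

```python
def lookup_batch_arg(trainer, model, batch_arg_name="batch_size"):
    """Hypothetical sketch: search for `batch_arg_name` on the model, then
    `model.hparams`, then `model.datamodule`, then `trainer.datamodule`,
    returning the first holder that defines it together with its value."""
    holders = [
        model,
        getattr(model, "hparams", None),
        getattr(model, "datamodule", None),
        getattr(trainer, "datamodule", None),
    ]
    for holder in holders:
        if holder is not None and hasattr(holder, batch_arg_name):
            return holder, getattr(holder, batch_arg_name)
    raise AttributeError(
        f"no attribute {batch_arg_name!r} found on the model or a datamodule"
    )
```

Checking the model first means a `batch_size` defined directly on the module shadows one stored in `hparams` or on a datamodule.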
@@ -263,16 +270,12 @@ def _adjust_batch_size(trainer,
                        factor: float = 1.0,
                        value: Optional[int] = None,
                        desc: str = None) -> Tuple[int, bool]:
-    """ Function for adjusting the batch size. It is expected that the user
-    has provided a model that has a hparam field called `batch_size` i.e.
-    `model.hparams.batch_size` should exist. Additionally there can be a
-    datamodule attached to either Trainer or model, in that case the attribute
-    also gets updated when present.
+    """ Helper function for adjusting the batch size.

     Args:
         trainer: instance of pytorch_lightning.Trainer

-        batch_arg_name: field where batch_size is stored in `model.hparams`
+        batch_arg_name: name of the field where batch_size is stored.

         factor: value which the old batch size is multiplied by to get the
             new batch size
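The `factor`/`value` adjustment rule described in the signature above can be sketched as follows. This is an illustrative simplification under our own assumptions (the name `adjust_value` is ours, and the real `_adjust_batch_size` also writes the result back onto the model or datamodule):

```python
from typing import Optional, Tuple

def adjust_value(current: int,
                 factor: float = 1.0,
                 value: Optional[int] = None) -> Tuple[int, bool]:
    """Hypothetical sketch: if an explicit `value` is given, use it as the
    new batch size; otherwise multiply the old size by `factor`.
    Returns the new size and whether it actually changed."""
    new_size = value if value is not None else int(current * factor)
    return new_size, new_size != current
```

For example, doubling during a binary-search scaling phase would pass `factor=2.0`, while pinning the size after a failure would pass an explicit `value`.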