
Commit 00e73c1

committed Sep 6, 2020
update docs
1 parent 0c2a5a2 commit 00e73c1

File tree: 1 file changed, +9 −6


pytorch_lightning/trainer/training_tricks.py

+9 −6
@@ -161,6 +161,13 @@ def scale_batch_size(self,
                 algorithm is terminated
 
             batch_arg_name: name of the attribute that stores the batch size.
+                It is expected that the user has provided a model or datamodule that has a hyperparameter
+                with that name. We will look for this attribute name in the following places
+
+                - `model`
+                - `model.hparams`
+                - `model.datamodule`
+                - `trainer.datamodule` (the datamodule passed to the tune method)
 
             **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
                 or datamodule.
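
For context, a minimal sketch of how the documented lookup plays out from the user's side: the model below stores `batch_size` via `save_hyperparameters()`, so the scaler finds it under `model.hparams`. The Trainer flags follow the Lightning API of this era, but the exact invocation is an assumption based on the docstring above, not part of this commit.

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class LitModel(pl.LightningModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            # save_hyperparameters() places batch_size under self.hparams,
            # one of the locations scale_batch_size() searches for batch_arg_name
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())

        def train_dataloader(self):
            ds = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
            return DataLoader(ds, batch_size=self.hparams.batch_size)

    # 'power' mode doubles the batch size until an OOM error is hit; the tune
    # step then writes the largest working value back to model.hparams.batch_size
    trainer = pl.Trainer(auto_scale_batch_size='power', max_epochs=1)
    trainer.tune(LitModel())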
@@ -263,16 +270,12 @@ def _adjust_batch_size(trainer,
                        factor: float = 1.0,
                        value: Optional[int] = None,
                        desc: str = None) -> Tuple[int, bool]:
-    """ Function for adjusting the batch size. It is expected that the user
-    has provided a model that has a hparam field called `batch_size` i.e.
-    `model.hparams.batch_size` should exist. Additionally there can be a
-    datamodule attached to either Trainer or model, in that case the attribute
-    also gets updated when present.
+    """ Helper function for adjusting the batch size.
 
     Args:
         trainer: instance of pytorch_lightning.Trainer
 
-        batch_arg_name: field where batch_size is stored in `model.hparams`
+        batch_arg_name: name of the field where batch_size is stored.
 
         factor: value which the old batch size is multiplied by to get the
             new batch size
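
To make the helper's contract concrete, here is a standalone, simplified sketch of the adjustment logic: resolve the attribute by name in the order listed above, then either overwrite it with `value` or scale it by `factor`. The toy classes, the function name, and the `(old, new)` return value are illustrative assumptions; the actual helper also consults the datamodules and returns a `Tuple[int, bool]`.

    from typing import Optional, Tuple

    class _ToyHparams:
        batch_size = 32  # stand-in for model.hparams

    class _ToyModel:
        hparams = _ToyHparams()  # stand-in for a LightningModule

    def adjust_batch_size_sketch(model,
                                 batch_arg_name: str = 'batch_size',
                                 factor: float = 1.0,
                                 value: Optional[int] = None,
                                 desc: Optional[str] = None) -> Tuple[int, int]:
        # Mirror the lookup order documented above: model first, then model.hparams
        # (the real helper also checks the attached datamodules)
        if hasattr(model, batch_arg_name):
            holder = model
        elif hasattr(model, 'hparams') and hasattr(model.hparams, batch_arg_name):
            holder = model.hparams
        else:
            raise ValueError(f'no attribute {batch_arg_name!r} on model or model.hparams')

        old_size = getattr(holder, batch_arg_name)
        # an explicit value wins; otherwise scale the old size by factor
        new_size = value if value is not None else int(old_size * factor)
        if desc:
            print(f'Batch size {old_size} {desc}, trying batch size {new_size}')
        setattr(holder, batch_arg_name, new_size)
        return old_size, new_size

    model = _ToyModel()
    adjust_batch_size_sketch(model, factor=2.0, desc='succeeded')  # 32 -> 64
    print(model.hparams.batch_size)  # 64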
