@@ -168,7 +168,7 @@ def forward(self, batch):

"""

- def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]:
+ def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]]:
r"""
Here you compute and return the training loss and some additional metrics for e.g.
the progress bar or logger.
@@ -186,8 +186,8 @@ def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, D
When implementing :meth:`training_step`, return whatever you need in that step:

- loss -> tensor scalar **REQUIRED**
- - progress_bar -> Dict for progress bar display. Must have only tensors
- - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+ - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
+ - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)

In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
@@ -202,14 +202,14 @@ def training_step(self, batch, batch_idx):
out = self(x)
loss = self.loss(out, x)

- logger_logs = {'training_loss': loss}  # optional (MUST ALL BE TENSORS)
+ logger_logs = {'training_loss': loss}  # optional

# if using TestTubeLogger or TensorBoardLogger you can nest scalars
- logger_logs = {'losses': logger_logs}  # optional (MUST ALL BE TENSORS)
+ logger_logs = {'losses': logger_logs}  # optional

output = {
'loss': loss,  # required
- 'progress_bar': {'training_loss': loss},  # optional (MUST ALL BE TENSORS)
+ 'progress_bar': {'training_loss': loss},  # optional
'log': logger_logs
}

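For context, here is a minimal sketch of what this change permits in practice. It is not part of the diff; the batch layout, `self.loss`, and the metric key names are assumptions that mirror the docstring example above. The point shown is that the `progress_bar` and `log` dicts may now mix scalar tensors with plain Python scalars (e.g. values obtained via `.item()`), while the top-level `'loss'` entry must remain a tensor so backpropagation still works.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss(out, x)

        # scalar tensors and Python scalars can now live side by side in these dicts
        logs = {'training_loss': loss, 'training_loss_float': loss.item()}

        return {
            'loss': loss,          # required, must stay a tensor (used for backprop)
            'progress_bar': logs,  # optional
            'log': logs,           # optional
        }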
@@ -259,8 +259,8 @@ def training_end(self, *args, **kwargs):
"""

def training_epoch_end(
- self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
- ) -> Dict[str, Dict[str, Tensor]]:
+ self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Union[float, Tensor]]]]]
+ ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""Called at the end of the training epoch with the outputs of all training steps.

.. code-block:: python
@@ -334,7 +334,7 @@ def training_epoch_end(self, outputs):
return results
"""

- def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+ def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]:
"""
Use this when training with dp or ddp2 because :meth:`training_step`
will operate on only part of the batch. However, this is still optional
@@ -358,8 +358,8 @@ def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str
Dict with loss key and optional log or progress bar keys.

- loss -> tensor scalar **REQUIRED**
- - progress_bar -> Dict for progress bar display. Must have only tensors
- - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+ - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
+ - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)

Examples:
.. code-block:: python
@@ -396,7 +396,7 @@ def training_step_end(self, outputs):
See the :ref:`multi-gpu-training` guide for more details.
"""

- def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+ def validation_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
r"""
Operates on a single batch of data from the validation set.
In this step you might generate examples or calculate anything of interest like accuracy.
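As an illustrative sketch only (the accuracy computation, `self.loss`, and the key names are assumptions, not taken from the diff): with the relaxed annotation, `validation_step` may return plain Python numbers alongside scalar tensors in its output dict.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss(out, y)

        # a Python float is now an acceptable value in the returned dict
        acc = (out.argmax(dim=1) == y).float().mean().item()

        return {'val_loss': loss, 'val_acc': acc}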
@@ -486,7 +486,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
the model goes back to training mode and gradients are enabled.
"""

- def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
+ def validation_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
"""
Use this when validating with dp or ddp2 because :meth:`validation_step`
will operate on only part of the batch. However, this is still optional
@@ -553,8 +553,8 @@ def validation_end(self, outputs):
"""

def validation_epoch_end(
- self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
- ) -> Dict[str, Dict[str, Tensor]]:
+ self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
+ ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""
Called at the end of the validation epoch with the outputs of all validation steps.

@@ -575,8 +575,8 @@ def validation_epoch_end(
Dict or OrderedDict.
May have the following optional keys:

- - progress_bar (dict for progress bar display; only tensors)
- - log (dict of metrics to add to logger; only tensors).
+ - progress_bar (dict for progress bar display; either scalar tensors or Python scalars)
+ - log (dict of metrics to add to logger; either scalar tensors or Python scalars).

Note:
If you didn't define a :meth:`validation_step`, this won't be called.
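A hedged sketch of how the relaxed value types play out at epoch end (the key names and the aggregation are illustrative assumptions, and `torch` is assumed to be imported as it is in this module): tensor-valued and float-valued metrics from the steps can be aggregated differently and still returned together.

.. code-block:: python

    def validation_epoch_end(self, outputs):
        # 'val_loss' values are scalar tensors, 'val_acc' values are Python floats
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        avg_acc = sum(o['val_acc'] for o in outputs) / len(outputs)  # stays a Python float

        logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return {'progress_bar': logs, 'log': logs}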
@@ -630,7 +630,7 @@ def validation_epoch_end(self, outputs):
return results
"""

- def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+ def test_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
r"""
Operates on a single batch of data from the test set.
In this step you'd normally generate examples or calculate anything of interest
@@ -713,7 +713,7 @@ def test_step(self, batch, batch_idx, dataloader_idx):
to training mode and gradients are enabled.
"""

- def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
+ def test_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
"""
Use this when testing with dp or ddp2 because :meth:`test_step` will operate
on only part of the batch. However, this is still optional
@@ -779,8 +779,8 @@ def test_end(self, outputs):
"""

def test_epoch_end(
- self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
- ) -> Dict[str, Dict[str, Tensor]]:
+ self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
+ ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""
Called at the end of a test epoch with the output of all test steps.

@@ -800,8 +800,8 @@ def test_epoch_end(
Return:
Dict or OrderedDict: Dict has the following optional keys:

- - progress_bar -> Dict for progress bar display. Must have only tensors.
- - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc).
+ - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars.
+ - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc).

Note:
If you didn't define a :meth:`test_step`, this won't be called.
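The new `outputs` annotation also covers the multiple-test-dataloader case, where `outputs` is a list of per-dataloader lists. A minimal, illustrative sketch (the dataloader setup and key names are assumptions, and `torch` is assumed to be imported as it is in this module):

.. code-block:: python

    def test_epoch_end(self, outputs):
        # with several test dataloaders, `outputs` is a list of per-dataloader lists
        if outputs and isinstance(outputs[0], list):
            flat = [o for per_loader in outputs for o in per_loader]
        else:
            flat = outputs

        avg_loss = torch.stack([o['test_loss'] for o in flat]).mean()
        return {'progress_bar': {'test_loss': avg_loss}, 'log': {'test_loss': avg_loss}}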