Skip to content

Commit 0ebfb78

Browse files
Examples: using new API (#1056)
* using new API * typo
1 parent bb7356b commit 0ebfb78

File tree

9 files changed

+16
-17
lines changed

9 files changed

+16
-17
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ To use lightning do 2 things:
145145
y_hat = self.forward(x)
146146
return {'val_loss': F.cross_entropy(y_hat, y)}
147147

148-
def validation_end(self, outputs):
148+
def validation_epoch_end(self, outputs):
149149
# OPTIONAL
150150
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
151151
tensorboard_logs = {'val_loss': avg_loss}
@@ -157,7 +157,7 @@ To use lightning do 2 things:
157157
y_hat = self.forward(x)
158158
return {'test_loss': F.cross_entropy(y_hat, y)}
159159

160-
def test_end(self, outputs):
160+
def test_epoch_end(self, outputs):
161161
# OPTIONAL
162162
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
163163
tensorboard_logs = {'test_loss': avg_loss}
@@ -268,7 +268,7 @@ def validation_step(self, batch, batch_idx):
268268
**And you also decide how to collate the output of all validation steps**
269269

270270
```python
271-
def validation_end(self, outputs):
271+
def validation_epoch_end(self, outputs):
272272
"""
273273
Called at the end of validation to aggregate outputs
274274
:param outputs: list of individual outputs of each validation step

docs/source/early_stopping.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Early stopping
44
Default behavior
55
----------------
66
By default early stopping will be enabled if `'val_loss'`
7-
is found in `validation_end()` return dict. Otherwise
7+
is found in `validation_epoch_end()` return dict. Otherwise
88
training will proceed with early stopping disabled.
99

1010
Enable Early Stopping
@@ -16,7 +16,7 @@ There are two ways to enable early stopping.
1616
.. code-block:: python
1717
1818
# A) Set early_stop_callback to True. Will look for 'val_loss'
19-
# in validation_end() return dict. If it is not found an error is raised.
19+
# in validation_epoch_end() return dict. If it is not found an error is raised.
2020
trainer = Trainer(early_stop_callback=True)
2121
2222
# B) Or configure your own callback

docs/source/experiment_reporting.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Here we show the validation loss in the progress bar
8787

8888
.. code-block:: python
8989
90-
def validation_end(self, outputs):
90+
def validation_epoch_end(self, outputs):
9191
loss = some_loss()
9292
...
9393

docs/source/introduction_guide.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ sample split in the `train_dataloader` method.
603603
loss = F.nll_loss(logits, y)
604604
return {'val_loss': loss}
605605
606-
def validation_end(self, outputs):
606+
def validation_epoch_end(self, outputs):
607607
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
608608
tensorboard_logs = {'val_loss': avg_loss}
609609
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
@@ -657,7 +657,7 @@ Just like the validation loop, we define exactly the same steps for testing:
657657
loss = F.nll_loss(logits, y)
658658
return {'val_loss': loss}
659659
660-
def test_end(self, outputs):
660+
def test_epoch_end(self, outputs):
661661
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
662662
tensorboard_logs = {'val_loss': avg_loss}
663663
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

pl_examples/basic_examples/lightning_module_template.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def validation_step(self, batch, batch_idx):
143143
# can also return just a scalar instead of a dict (return loss_val)
144144
return output
145145

146-
def validation_end(self, outputs):
146+
def validation_epoch_end(self, outputs):
147147
"""
148148
Called at the end of validation to aggregate outputs
149149
:param outputs: list of individual outputs of each validation step

pl_examples/full_examples/imagenet/imagenet_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def validation_step(self, batch, batch_idx):
8484

8585
return output
8686

87-
def validation_end(self, outputs):
87+
def validation_epoch_end(self, outputs):
8888

8989
tqdm_dict = {}
9090

tests/models/debug.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def validation_step(self, batch, batch_idx):
3434
y_hat = self.forward(x)
3535
return {'val_loss': self.my_loss(y_hat, y)}
3636

37-
def validation_end(self, outputs):
37+
def validation_epoch_end(self, outputs):
3838
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
3939
return avg_loss
4040

tests/models/mixins.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import torch
44
from torch import optim
5-
from pytorch_lightning.core.decorators import data_loader
65

76

87
class LightValidationStepMixin:
@@ -64,7 +63,7 @@ class LightValidationMixin(LightValidationStepMixin):
6463
when val_dataloader returns a single dataloader
6564
"""
6665

67-
def validation_end(self, outputs):
66+
def validation_epoch_end(self, outputs):
6867
"""
6968
Called at the end of validation to aggregate outputs
7069
:param outputs: list of individual outputs of each validation step
@@ -163,7 +162,7 @@ class LightValidationMultipleDataloadersMixin(LightValidationStepMultipleDataloa
163162
when val_dataloader returns multiple dataloaders
164163
"""
165164

166-
def validation_end(self, outputs):
165+
def validation_epoch_end(self, outputs):
167166
"""
168167
Called at the end of validation to aggregate outputs
169168
:param outputs: list of individual outputs of each validation step
@@ -271,7 +270,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
271270
class LightTestMixin(LightTestStepMixin):
272271
"""Ritch test mixin."""
273272

274-
def test_end(self, outputs):
273+
def test_epoch_end(self, outputs):
275274
"""
276275
Called at the end of validation to aggregate outputs
277276
:param outputs: list of individual outputs of each validation step
@@ -561,7 +560,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx, **kwargs):
561560

562561
class LightTestMultipleDataloadersMixin(LightTestStepMultipleDataloadersMixin):
563562

564-
def test_end(self, outputs):
563+
def test_epoch_end(self, outputs):
565564
"""
566565
Called at the end of validation to aggregate outputs
567566
:param outputs: list of individual outputs of each validation step

tests/trainer/test_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestS
604604
pass
605605

606606
class LocalModelNoStep(LightTrainDataloader, TestModelBase):
607-
def test_end(self, outputs):
607+
def test_epoch_end(self, outputs):
608608
return {}
609609

610610
# Misconfig when neither test_step or test_end is implemented

0 commit comments

Comments
 (0)