
Commit b4f65da

williamFalcon authored and tullie committed
* training_end renamed to training_step_end
* fix lost model reference
1 parent 8c46ba1 commit b4f65da

File tree

12 files changed: 1391 additions, 636 deletions

docs/source/experiment_reporting.rst

+5 -17

@@ -34,47 +34,35 @@ Log metrics
 
 To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
 
-1. Training_end, validation_end, test_end will all log anything in the "log" key of the return dict.
+1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.
 
 .. code-block:: python
 
-    def training_end(self, outputs):
+    def training_epoch_end(self, outputs):
         loss = some_loss()
         ...
 
         logs = {'train_loss': loss}
         results = {'log': logs}
         return results
 
-    def validation_end(self, outputs):
+    def validation_epoch_end(self, outputs):
         loss = some_loss()
         ...
 
         logs = {'val_loss': loss}
         results = {'log': logs}
         return results
 
-    def test_end(self, outputs):
+    def test_epoch_end(self, outputs):
         loss = some_loss()
         ...
 
         logs = {'test_loss': loss}
         results = {'log': logs}
         return results
 
-2. Most of the time, you only need training_step and not training_end. You can also return logs from here:
-
-.. code-block:: python
-
-    def training_step(self, batch, batch_idx):
-        loss = some_loss()
-        ...
-
-        logs = {'train_loss': loss}
-        results = {'log': logs}
-        return results
-
-3. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
+2. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
 For instance, here we log images using tensorboard.
 
 .. code-block:: python
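The hunk above ends at the code-block directive, so the tensorboard image-logging snippet it introduces is not part of this diff. A minimal sketch of what such a snippet could look like, assuming the TensorBoard logger (so that self.logger.experiment exposes a SummaryWriter); the module, layer sizes, and tag name below are illustrative, not taken from the commit:

.. code-block:: python

    import torch
    import torchvision
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.forward(x), y)

            if batch_idx == 0:
                # with the TensorBoard logger, self.logger.experiment is a SummaryWriter
                grid = torchvision.utils.make_grid(x[:8])
                self.logger.experiment.add_image('example_images', grid, self.current_epoch)

            return {'loss': loss, 'log': {'train_loss': loss}}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)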

docs/source/hooks.rst

+1 -1

@@ -26,7 +26,7 @@ Training loop
 - on_batch_start
 - tbptt_split_batch
 - training_step
-- training_end (optional)
+- training_step_end (optional)
 - backward
 - on_after_backward
 - optimizer.step()
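For orientation, a rough sketch of how the renamed hook pairs with training_step in the loop order listed above (the module, layer, and dict keys are illustrative assumptions, not part of this commit):

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 10)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            # under dp/ddp2 this sees only one slice of the batch
            x, y = batch
            return {'y_hat': self.forward(x), 'y': y}

        def training_step_end(self, step_outputs):
            # optional hook from the list above: called on the gathered
            # outputs of all slices, just before backward
            loss = torch.nn.functional.cross_entropy(step_outputs['y_hat'], step_outputs['y'])
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)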

docs/source/multi_gpu.rst

+40 -5

@@ -165,12 +165,13 @@ you will only be operating on one of those pieces.
     y_0 = batch
 
 For most metrics, this doesn't really matter. However, if you want
-full batch statistics or want to use the outputs of the training_step
-to do something like a softmax, you can use the `training_end` step.
+to add something to your computational graph (like softmax)
+using all batch parts, you can use the `training_step_end` step.
 
 .. code-block:: python
 
-    def training_end(self, outputs):
+    def training_step_end(self, outputs):
+        # only needed when using dp
         outputs = torch.cat(outputs, dim=1)
         softmax = torch.nn.functional.softmax(outputs, dim=1)
         out = softmax.mean()
@@ -195,9 +196,43 @@ In pseudocode, the full sequence is:
         out = gpu_model(batch_split)
         all_results.append(out)
 
-    # calculate statistics for all parts of the batch
-    full out = model.training_end(all_results)
+    # use the full batch for something like softmax
+    full_out = model.training_step_end(all_results)
 
+To illustrate why this is needed, let's look at DataParallel:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.forward(x)
+
+        # on dp or ddp2, if we did the softmax now it would be wrong
+        # because batch is actually only a piece of the full batch
+        return y_hat
+
+    def training_step_end(self, batch_parts_outputs):
+        # batch_parts_outputs has the outputs of each part of the batch
+
+        # do the softmax here
+        outputs = torch.cat(batch_parts_outputs, dim=1)
+        softmax = torch.nn.functional.softmax(outputs, dim=1)
+        out = softmax.mean()
+
+        return out
+
+If `training_step_end` is defined, it will be called regardless of tpu, dp, ddp, etc., which means
+it will behave the same no matter the backend.
+
+The validation and test steps also have the same option when using dp:
+
+.. code-block:: python
+
+    def validation_step_end(self, batch_parts_outputs):
+        ...
+
+    def test_step_end(self, batch_parts_outputs):
+        ...
 
 Implement Your Own Distributed (DDP) training
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
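For reference, a rough sketch of the validation counterpart described in the hunk above, filled in under dp (the module, loss, dict keys, and Trainer arguments are illustrative assumptions, not taken from this commit):

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 10)

        def forward(self, x):
            return self.layer(x)

        def validation_step(self, batch, batch_idx):
            # on dp/ddp2, each call sees only one piece of the full batch
            x, y = batch
            return {'y_hat': self.forward(x), 'y': y}

        def validation_step_end(self, batch_parts_outputs):
            # outputs from all batch parts are gathered here, so full-batch
            # statistics (softmax, loss, accuracy, ...) are safe to compute
            val_loss = torch.nn.functional.cross_entropy(
                batch_parts_outputs['y_hat'], batch_parts_outputs['y'])
            return {'val_loss': val_loss}

    # illustrative wiring: split each batch across 2 GPUs with dp
    # trainer = pl.Trainer(gpus=2, distributed_backend='dp')
    # trainer.fit(model, train_loader, val_loader)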
