
Commit 1d565e1

add tests for single scalar return from training (#2587)
* add tests for single scalar return from training
* add tests for single scalar return from training
* add tests for single scalar return from training
* add tests for single scalar return from training
* add tests for single scalar return from training
1 parent a34609e commit 1d565e1

File tree: 5 files changed, +224 -4 lines

pytorch_lightning/trainer/logging.py  (+11)
@@ -98,6 +98,17 @@ def process_output(self, output, train=False):
 
         Separates loss from logging and progress bar metrics
         """
+        # --------------------------
+        # handle single scalar only
+        # --------------------------
+        # single scalar returned from a xx_step
+        if isinstance(output, torch.Tensor):
+            progress_bar_metrics = {}
+            log_metrics = {}
+            callback_metrics = {}
+            hiddens = None
+            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
         # ---------------
         # EXTRACT CALLBACK KEYS
         # ---------------
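For context, this early return is what lets a user's training_step return a bare loss tensor instead of the usual dict. A minimal sketch of that usage (the model, layer sizes, and optimizer below are illustrative and not part of this commit):

import torch
from torch.nn import functional as F
import pytorch_lightning as pl


class ScalarReturnModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # return only the scalar loss; process_output now short-circuits on this
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)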

pytorch_lightning/trainer/training_loop.py  (+4 -1)
@@ -792,7 +792,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         )
 
         # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-        training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+        else:
+            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
 
         # accumulate loss
         # (if accumulate_grad_batches = 1 no effect)
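The new branch avoids calling recursive_detach on a bare tensor: Lightning's utility walks a dict of step outputs and detaches any tensor values, whereas a scalar return can simply be detached directly. A rough sketch of the idea (an approximation, not the actual recursive_detach implementation):

import torch


def recursive_detach_sketch(output):
    # detach tensors, recurse into dicts, pass everything else through
    if isinstance(output, torch.Tensor):
        return output.detach()
    if isinstance(output, dict):
        return {k: recursive_detach_sketch(v) for k, v in output.items()}
    return output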

tests/base/deterministic_model.py  (+41)
@@ -52,6 +52,47 @@ def count_num_graphs(self, result, num_graphs=0):
 
         return num_graphs
 
+    # ---------------------------
+    # scalar return
+    # ---------------------------
+    def training_step_scalar_return(self, batch, batch_idx):
+        acc = self.step(batch, batch_idx)
+        self.training_step_called = True
+        return acc
+
+    def training_step_end_scalar(self, output):
+        self.training_step_end_called = True
+
+        # make sure loss has the grad
+        assert isinstance(output, torch.Tensor)
+        assert output.grad_fn is not None
+
+        # make sure nothing else has grads
+        assert self.count_num_graphs({'loss': output}) == 1
+
+        assert output == 171
+
+        return output
+
+    def training_epoch_end_scalar(self, outputs):
+        """
+        There should be an array of scalars without graphs that are all 171 (4 of them)
+        """
+        self.training_epoch_end_called = True
+
+        if self.use_dp or self.use_ddp2:
+            pass
+        else:
+            # only saw 4 batches
+            assert len(outputs) == 4
+            for batch_out in outputs:
+                assert batch_out == 171
+                assert batch_out.grad_fn is None
+                assert isinstance(batch_out, torch.Tensor)
+
+        prototype_loss = outputs[0]
+        return prototype_loss
+
     # --------------------------
     # dictionary returns
     # --------------------------
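DeterministicModel is the shared test helper: judging by the asserts above, its step produces a known loss of 171 for every batch, which is why the new hooks can check exact values and graph counts rather than approximate ones.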

tests/trainer/test_trainer_steps.py → tests/trainer/test_trainer_steps_dict_return.py  (+3 -3)
@@ -1,10 +1,10 @@
+"""
+Tests to ensure that the training loop works with a dict
+"""
 from pytorch_lightning import Trainer
 from tests.base.deterministic_model import DeterministicModel
-import pytest
-import torch
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_training_step_dict(tmpdir):
     """
     Tests that only training_step can be used
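Besides the rename, this change drops the multi-GPU skipif guard and the now-unused pytest/torch imports, so the dict-return test also runs on CPU, and adds a module docstring to mirror the new scalar-return tests.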
(new test file)  (+165)
@@ -0,0 +1,165 @@
+"""
+Tests to ensure that the training loop works with a scalar
+"""
+from pytorch_lightning import Trainer
+from tests.base.deterministic_model import DeterministicModel
+import torch
+
+
+def test_training_step_scalar(tmpdir):
+    """
+    Tests that only training_step that returns a single scalar can be used
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def training_step_scalar_with_step_end(tmpdir):
+    """
+    Checks train_step with scalar only + training_step_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_full_training_loop_scalar(tmpdir):
+    """
+    Checks train_step + training_step_end + training_epoch_end
+    (all with scalar return from train_step)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_train_step_epoch_end_scalar(tmpdir):
+    """
+    Checks train_step + training_epoch_end (NO training_step_end)
+    (with scalar return)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = None
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
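All four tests follow the same pattern: run fit, then drive run_training_batch and optimizer_closure directly on a single batch to verify that the scalar return survives as a plain detached tensor (value 171) in training_step_output_for_epoch_end and in the closure's 'loss'. They can be filtered locally with pytest's -k scalar selector. Note that training_step_scalar_with_step_end is missing the test_ prefix, so pytest will not collect it by default.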
