Skip to content

Commit 8eb152a

Browse files
williamFalcon authored and atee committed
fix result obj dp auto reduce (Lightning-AI#3013)
* fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * added warning when changing monitor and using results obj
1 parent a03a0bf commit 8eb152a

File tree

6 files changed

+105
-10
lines changed

6 files changed

+105
-10
lines changed

pytorch_lightning/core/step_result.py

+8
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,14 @@ def reduce_across_time(cls, time_outputs):
361361
result['meta'] = meta
362362
return result
363363

def dp_reduce(self):
    """Auto-reduce dp/ddp2 outputs by averaging over the device dimension.

    Under dp/ddp2 every logged value arrives either as a tensor whose last
    dimension indexes the devices, or as a plain list of per-device scalars.
    This averages each entry in place; the 'meta' bookkeeping dict is left
    untouched.
    """
    for k, value in self.items():
        # 'meta' holds internal bookkeeping, never a metric — skip it
        if k == 'meta':
            continue

        # per-device scalars may arrive as a plain Python list
        if isinstance(value, list):
            value = torch.tensor(value)

        # mean() is undefined for integral dtypes (torch raises); promote
        # int metrics (e.g. counts) to float so the reduce cannot crash
        if not value.is_floating_point():
            value = value.float()

        self[k] = value.mean(dim=-1)
364372
@property
365373
def should_reduce_on_epoch_end(self) -> bool:
366374
return self['meta']['_internal']['_reduce_on_epoch']

pytorch_lightning/trainer/evaluation_loop.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -343,17 +343,20 @@ def _evaluate(
343343
m = 'only EvalResults or dicts are allowed from validation_step'
344344
raise MisconfigurationException(m)
345345

346+
# ------------------
347+
# EVAL STEP END
348+
# ------------------
346349
# on dp / ddp2 might still want to do something with the batch parts
347-
if test_mode:
348-
if self.is_overridden('test_step_end'):
349-
model_ref = self.get_model()
350-
with self.profiler.profile('test_step_end'):
351-
output = model_ref.test_step_end(output)
352-
else:
353-
if self.is_overridden('validation_step_end'):
354-
model_ref = self.get_model()
355-
with self.profiler.profile('validation_step_end'):
356-
output = model_ref.validation_step_end(output)
350+
eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
351+
if self.is_overridden(eval_step_end_hook_name):
352+
model_ref = self.get_model()
353+
with self.profiler.profile(eval_step_end_hook_name):
354+
eval_step_end = getattr(model_ref, eval_step_end_hook_name)
355+
output = eval_step_end(output)
356+
357+
elif is_result_obj and (self.use_dp or self.use_ddp2):
358+
# result auto reduce
359+
output.dp_reduce()
357360

358361
# callbacks (on __batch_end)
359362
if test_mode:

pytorch_lightning/trainer/training_loop.py

+5
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
12211221
else:
12221222
output = self.model.training_step(*args)
12231223

1224+
is_result_obj = isinstance(output, Result)
1225+
12241226
# allow any mode to define training_step_end
12251227
# do something will all the dp outputs (like softmax)
12261228
if self.is_overridden('training_step_end'):
@@ -1229,6 +1231,9 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
12291231
# TODO: modify when using result obj
12301232
output = model_ref.training_step_end(output)
12311233

1234+
elif is_result_obj and (self.use_dp or self.use_ddp2):
1235+
output.dp_reduce()
1236+
12321237
# allow any mode to define training_end
12331238
# TODO: remove in 1.0.0
12341239
if self.is_overridden('training_end'):

tests/base/model_train_steps.py

+22
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,28 @@ def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=
7979
self.training_step_called = True
8080
return result
8181

def training_step_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """Training step that returns a TrainResult, used to exercise dp auto-reduce."""
    features, targets = batch
    features = features.view(features.size(0), -1)

    # forward pass on this replica's device
    predictions = self(features.to(self.device))

    loss_val = self.loss(targets.to(predictions.device), predictions)

    # alternate between tensors and scalars for "log" and "progress_bar"
    log_val = loss_val if batch_idx % 2 != 0 else loss_val.item()

    result = TrainResult(loss_val)
    result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
    result.log('train_some_val', log_val * log_val)

    self.training_step_called = True

    return result
82104
def training_step_end_full_loop_result_obj_dp(self, result):
83105
"""
84106
Full loop flow train step (result obj + dp)

tests/base/model_valid_steps.py

+22
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
5252
})
5353
return result
5454

def validation_step_result_obj_dp(self, batch, batch_idx, *args, **kwargs):
    """Validation step that returns an EvalResult, used to exercise dp auto-reduce."""
    x, y = batch
    x = x.view(x.size(0), -1)

    # forward pass on this replica's device
    y_hat = self(x.to(self.device))
    y = y.to(y_hat.device)

    loss_val = self.loss(y, y_hat)

    # accuracy: fraction of argmax predictions matching the labels
    labels_hat = torch.argmax(y_hat, dim=1)
    num_correct = torch.sum(y == labels_hat).item()
    val_acc = torch.tensor(num_correct / (len(y) * 1.0)).type_as(x)

    result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
    result.log_dict({
        'val_loss': loss_val,
        'val_acc': val_acc,
    })

    self.validation_step_called = True
    return result
5577
def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
5678
"""
5779
Lightning calls this inside the validation loop

tests/trainer/test_trainer_steps_result_return.py

+35
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,41 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
535535
assert 'epoch_train_epoch_end_metric' in seen_keys
536536

537537

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_loop_steps_only_dp(tmpdir):
    """Result-object train/val steps auto-reduce correctly under dp with no *_step_end hooks."""
    os.environ['PL_DEV_DEBUG'] = '1'

    num_batches = 10
    num_epochs = 3

    model = EvalModelTemplate()

    # strip every optional hook so only the bare step functions run
    for hook in ('validation_step', 'test_step', 'training_step_end',
                 'training_epoch_end', 'validation_step_end',
                 'validation_epoch_end', 'test_dataloader'):
        setattr(model, hook, None)

    # wire in the dp-flavoured result-object steps
    model.training_step = model.training_step_result_obj_dp
    model.validation_step = model.validation_step_result_obj_dp

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=num_epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=num_batches,
        weights_summary=None,
    )

    trainer.fit(model)

    assert model.training_step_called
    assert model.validation_step_called
538573
def test_result_map(tmpdir):
539574
result = TrainResult()
540575
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})

0 commit comments

Comments (0)