Skip to content

Commit 6d2e0c5

Browse files
committed
Fixes #2455
1 parent bea5171 commit 6d2e0c5

File tree

1 file changed: +8 additions, −12 deletions

pytorch_lightning/callbacks/early_stopping.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -141,25 +141,21 @@ def _run_early_stopping_check(self, trainer, pl_module):
141141
if not isinstance(current, torch.Tensor):
142142
current = torch.tensor(current)
143143

144+
# in ddp, reduce the stopping metric so every process conditions the same
145+
if trainer.use_ddp or trainer.use_ddp2:
146+
print(f'RANK: {trainer.global_rank}, BEFORE: {current}')
147+
current = current.to(pl_module.device)
148+
dist.all_reduce(current, op=dist.reduce_op.MAX)
149+
print(f'RANK: {trainer.global_rank}, AFTER: {current}')
150+
144151
if self.monitor_op(current - self.min_delta, self.best_score):
145152
self.best_score = current
146153
self.wait_count = 0
147154
else:
148155
self.wait_count += 1
149156
should_stop = self.wait_count >= self.patience
150157

151-
# check flag across all GPUs
152-
should_stop = torch.tensor(int(should_stop), device=pl_module.device)
153-
if trainer.use_ddp or trainer.use_ddp2:
154-
print(f'RANK: {trainer.global_rank} REDUCING...')
155-
dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
156-
print(f'RANK: {trainer.global_rank} REDUCED...')
157-
158-
print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')
159-
160-
# do actual stop
161-
print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
162-
if bool(should_stop.item()):
158+
if bool(should_stop):
163159
print(f'RANK: {trainer.global_rank}, STOPPING...')
164160
self.stopped_epoch = trainer.current_epoch
165161
trainer.should_stop = True

0 commit comments

Comments (0)