
Commit 020c332

Clean up (#2467)
* Fixes #2455
* added early stop tpu test
1 parent e77add3 commit 020c332

File tree

3 files changed: +78 -3 lines changed

pytorch_lightning/callbacks/early_stopping.py

+33-3
@@ -9,13 +9,22 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 
 torch_inf = torch.tensor(np.Inf)
 
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
 
 class EarlyStopping(Callback):
     r"""
@@ -138,17 +147,38 @@ def _run_early_stopping_check(self, trainer, pl_module):
 
         current = logs.get(self.monitor)
         if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current)
+            current = torch.tensor(current, device=pl_module.device)
 
-        if self.monitor_op(current - self.min_delta, self.best_score):
+        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-            if self.wait_count >= self.patience:
+            should_stop = self.wait_count >= self.patience
+
+            if bool(should_stop):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
 
+        # stop every ddp process if any world process decides to stop
+        self._stop_distributed_training(trainer, pl_module)
+
+    def _stop_distributed_training(self, trainer, pl_module):
+
+        # in ddp make sure all processes stop when one is flagged
+        if trainer.use_ddp or trainer.use_ddp2:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+            dist.all_reduce(stop, op=dist.reduce_op.SUM)
+            dist.barrier()
+            trainer.should_stop = stop == trainer.world_size
+
+        # if trainer.use_tpu:
+        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+        #     xm.all_reduce('sum', [stop])
+        #     print(type(stop))
+        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        #     trainer.should_stop = stop.item() == trainer.world_size
+
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
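
For readers skimming the diff, here is a minimal standalone sketch (not part of the commit; the helper name and arguments are illustrative) of the stop-flag synchronization the new `_stop_distributed_training` hook performs in the DDP/DDP2 branch: every process contributes its local flag, and training halts only once the summed flags equal the world size.

```python
# Illustrative sketch of the DDP stop-flag sync shown in the diff above.
import torch
import torch.distributed as dist

def sync_should_stop(local_should_stop: bool, device: torch.device) -> bool:
    # each rank contributes 0 or 1
    stop = torch.tensor(int(local_should_stop), device=device)
    # sum the flags across all ranks; ReduceOp.SUM is the non-deprecated
    # spelling of the reduce_op.SUM alias used in the diff
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)
    dist.barrier()  # keep ranks in lockstep before acting on the result
    # stop only when every rank has flagged (flag sum == world size)
    return int(stop.item()) == dist.get_world_size()
```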

tests/models/test_gpu.py

+24
@@ -58,6 +58,30 @@ def test_multi_gpu_model(tmpdir, backend):
     memory.get_memory_profile('min_max')
 
 
+@pytest.mark.spawn
+@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2'])
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_multi_gpu_early_stop(tmpdir, backend):
+    """Make sure DDP works. with early stopping"""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        early_stop_callback=True,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend=backend,
+    )
+
+    model = EvalModelTemplate()
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result
+
+
 @pytest.mark.spawn
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
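
The new test drives early stopping through the `early_stop_callback=True` flag. As a hedged usage note (not part of the commit), the same Trainer argument also accepted a configured `EarlyStopping` instance in this era of the API; the monitor and patience values below are illustrative, not taken from the diff.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# illustrative configuration; 'val_loss' was the default monitored metric
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0, patience=3, verbose=True)
trainer = Trainer(
    early_stop_callback=early_stop,  # instead of early_stop_callback=True
    max_epochs=50,
    limit_train_batches=10,
    limit_val_batches=10,
    gpus=[0, 1],
    distributed_backend='ddp',
)
```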

tests/models/test_tpu.py

+21
@@ -19,6 +19,27 @@
     TPU_AVAILABLE = True
 
 
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
+    pytest.param([1], 'xla:1'),
+    pytest.param([8], 'xla:8'),
+])
+def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):
+    """Test if single TPU core training works"""
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        early_stop_callback=True,
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        tpu_cores=tpu_cores,
+    )
+    trainer.fit(model)
+    assert torch_xla._XLAC._xla_get_default_device() == expected_device
+
+
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
     pytest.param([1], 'xla:1'),
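
The commented-out TPU branch in early_stopping.py above hints at the equivalent flag synchronization on XLA devices. A hedged sketch of that idea follows (not part of the commit; the helper name is hypothetical and the torch_xla calls mirror the disabled code in the diff).

```python
import torch
import torch_xla.core.xla_model as xm

def tpu_sync_should_stop(local_should_stop: bool) -> bool:
    # each TPU core contributes 0 or 1
    stop = torch.tensor(int(local_should_stop), device=xm.xla_device())
    # sum the flags across all cores (single-tensor form returns the reduced value)
    stop = xm.all_reduce('sum', stop)
    # make every core reach this point before reading the result
    xm.rendezvous('pl.EarlyStoppingCallback.stop_distributed_training_check')
    return int(stop.item()) == xm.xrt_world_size()
```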
