@@ -9,13 +9,22 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 
 torch_inf = torch.tensor(np.Inf)
 
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
 
 class EarlyStopping(Callback):
     r"""
@@ -138,17 +147,38 @@ def _run_early_stopping_check(self, trainer, pl_module):
 
         current = logs.get(self.monitor)
         if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current)
+            current = torch.tensor(current, device=pl_module.device)
 
-        if self.monitor_op(current - self.min_delta, self.best_score):
+        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-            if self.wait_count >= self.patience:
+            should_stop = self.wait_count >= self.patience
+
+            if bool(should_stop):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
 
+        # stop every ddp process if any world process decides to stop
+        self._stop_distributed_training(trainer, pl_module)
+
+    def _stop_distributed_training(self, trainer, pl_module):
+
+        # in ddp make sure all processes stop when one is flagged
+        if trainer.use_ddp or trainer.use_ddp2:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+            dist.all_reduce(stop, op=dist.reduce_op.SUM)
+            dist.barrier()
+            trainer.should_stop = stop == trainer.world_size
+
+        # if trainer.use_tpu:
+        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+        #     xm.all_reduce('sum', [stop])
+        #     print(type(stop))
+        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        #     trainer.should_stop = stop.item() == trainer.world_size
+
     def on_train_end(self, trainer, pl_module):
        if self.stopped_epoch > 0 and self.verbose > 0:
            rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
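The core of the change is the new `_stop_distributed_training` hook: each rank encodes its local stop decision as a tensor, all-reduces the sum across the process group, and training halts only once the reduced sum equals `world_size`, i.e. when every rank has flagged stop. Below is a minimal standalone sketch of that pattern; it is not part of the commit, the helper name `sync_should_stop` and its arguments are made up for illustration, it assumes the default process group is already initialized, and it uses `dist.ReduceOp.SUM`, the current spelling of the deprecated `dist.reduce_op.SUM` that appears in the diff.

import torch
import torch.distributed as dist

def sync_should_stop(local_should_stop: bool, device: torch.device) -> bool:
    # Encode this rank's decision as an integer tensor so it can be all-reduced.
    stop = torch.tensor(int(local_should_stop), device=device)
    # Sum the stop flags from every rank in the default process group.
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)
    # Mirror the diff's check: stop only when all ranks agree (sum == world_size);
    # comparing against > 0 would instead stop as soon as any single rank flags it.
    return stop.item() == dist.get_world_size()

In the callback itself the reduced result is written back to `trainer.should_stop`, and the `dist.barrier()` call keeps the ranks in step before they act on it.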