
Commit 57d5f6e

Barrier (#2257)
* remove barriers
1 parent 03ab574 commit 57d5f6e

File tree: 6 files changed (+44, -12 lines)

docs/source/trainer.rst

+1 -1

@@ -9,7 +9,7 @@ Trainer
     :exclude-members:
         run_pretrain_routine,
         _abc_impl,
-        _Trainer_set_random_port,
+        set_random_port,
         _Trainer__set_root_gpu,
         _Trainer__init_optimizers,
         _Trainer__parse_gpu_ids,

pytorch_lightning/core/hooks.py

+1 -1

@@ -30,7 +30,7 @@ def teardown(self, stage: str):
        Called at the end of fit and test.

        Args:
-           step: either 'fit' or 'test'
+           stage: either 'fit' or 'test'
        """

    def on_fit_start(self):
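
For orientation, a minimal sketch of a LightningModule overriding the two hooks documented in this file; the class name is made up and this is not code from the commit. Both hooks receive the stage string, matching the corrected docstring.

import pytorch_lightning as pl


class HookedModule(pl.LightningModule):  # hypothetical example, for illustration only
    def setup(self, stage: str):
        # called on every process before fit/test work begins; stage is 'fit' or 'test'
        print(f'setup(stage={stage})')

    def teardown(self, stage: str):
        # called at the end of fit/test; stage is 'fit' or 'test'
        print(f'teardown(stage={stage})')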

pytorch_lightning/trainer/distrib_data_parallel.py

+5

@@ -475,6 +475,11 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
        model.trainer = self
        model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)

+       # call setup after the ddp process has connected
+       self.setup()
+       if self.is_function_implemented('setup', model):
+           model.setup()
+
        # on world_size=0 let everyone know training is starting
        if self.is_global_zero:
            log.info('-' * 100)
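
As background for the new ordering above, here is a rough standalone sketch (gloo backend on localhost; entirely illustrative, not Lightning internals) showing why a setup-style hook should run only after the process group is connected: collectives fail before init_process_group has completed.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    # connect first, mirroring init_ddp_connection above
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    # a setup()-style hook run from this point on may safely use collectives
    t = torch.tensor([rank])
    dist.all_reduce(t)  # would raise if the default group were not initialized
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, nprocs=2, args=(2,))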

pytorch_lightning/trainer/distrib_parts.py

+19

@@ -155,6 +155,11 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
        return move_data_to_device(batch, device)

    def single_gpu_train(self, model):
+       # call setup
+       self.setup('fit')
+       if self.is_function_implemented('setup', model):
+           model.setup('fit')
+
        model.cuda(self.root_gpu)

        # CHOOSE OPTIMIZER

@@ -171,6 +176,11 @@ def single_gpu_train(self, model):
        self.run_pretrain_routine(model)

    def tpu_train(self, tpu_core_idx, model):
+       # call setup after the ddp process has connected
+       self.setup('fit')
+       if self.is_function_implemented('setup', model):
+           model.setup('fit')
+
        # put model on tpu
        self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
        model.to(self._device)

@@ -205,6 +215,10 @@ def tpu_train(self, tpu_core_idx, model):
        self.save_spawn_weights(model)

    def dp_train(self, model):
+       # call setup after the ddp process has connected
+       self.setup('fit')
+       if self.is_function_implemented('setup', model):
+           model.setup('fit')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well

@@ -246,6 +260,11 @@ def dp_train(self, model):
        model.forward = model_autocast_original_forward

    def horovod_train(self, model):
+       # call setup after the ddp process has connected
+       self.setup('fit')
+       if self.is_function_implemented('setup', model):
+           model.setup('fit')
+
        if torch.cuda.is_available() and self.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.root_gpu == hvd.local_rank()
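
The same pattern repeats for every non-DDP entry point above. As a hedged end-to-end sketch (the module below is hypothetical, not part of this commit), a user-defined setup('fit') now runs inside these code paths when fit() dispatches to them, for example via Trainer(gpus=1) and single_gpu_train:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):  # illustrative only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage):
        # invoked by single_gpu_train / dp_train / tpu_train / horovod_train
        self.dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=8)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self.layer(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(max_epochs=1, gpus=1)  # exercises the single_gpu_train path
trainer.fit(ToyModel())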

pytorch_lightning/trainer/trainer.py

+11 -9

@@ -857,12 +857,6 @@ def fit(
            model.prepare_data()
            self._is_data_prepared = True

-       self.barrier('fit_prepare_data')
-
-       self.setup('fit')
-       if self.is_function_implemented('setup', model):
-           model.setup('fit')
-
        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):

@@ -897,19 +891,19 @@ def fit(
                self.ddp_train(task, model)

            elif self.distributed_backend == 'cpu_ddp':
-               self._set_random_port
+               self.set_random_port()
                self.model = model
                mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

            elif self.distributed_backend == 'ddp_spawn':
-               self._set_random_port
+               self.set_random_port()
                model.share_memory()

                # spin up peers
                mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

            elif self.distributed_backend == 'ddp':
-               self._set_random_port
+               self.set_random_port()
                self.spawn_ddp_children(model)

        # 1 gpu or dp option triggers training using DP module

@@ -932,6 +926,9 @@ def fit(
            # track for predict
            self.model = model

+           # wait for all prepare data nodes to finish
+           self.barrier('setup')
+
            # train
            if self.tpu_id is not None:
                self.tpu_train(self.tpu_id, model)

@@ -948,6 +945,11 @@ def fit(
            if self.use_amp:
                raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

+           # call setup after the ddp process has connected
+           self.setup('fit')
+           if self.is_function_implemented('setup', model):
+               model.setup('fit')
+
            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
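
For readers unfamiliar with the barrier call added above: a hedged sketch of what such a named barrier could look like on top of torch.distributed (an assumption for illustration, not the Trainer's actual implementation). Guarding on is_initialized() keeps single-process runs working unchanged.

import torch.distributed as dist


def barrier(name: str = '') -> None:
    # block until every rank reaches this point; 'name' is only a debug label here
    if dist.is_available() and dist.is_initialized():
        dist.barrier()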

pytorch_lightning/trainer/training_loop.py

+7 -1

@@ -153,6 +153,7 @@ def training_step(self, batch, batch_idx):
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
+import torch.distributed as torch_distrib

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback

@@ -258,7 +259,7 @@ def get_model(self) -> LightningModule:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
-   def is_function_implemented(self, *args):
+   def is_function_implemented(self, *args, **kwargs):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod

@@ -701,6 +702,11 @@ def _get_optimizers_iterable(self):
    def run_training_teardown(self):
        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
            return
+
+       # clean up dist group
+       if self.use_ddp or self.use_ddp2:
+           torch_distrib.destroy_process_group()
+
        # Train end events
        with self.profiler.profile('on_train_end'):
            # callbacks
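
Relatedly, a defensive variant of the cleanup added above (a sketch under the assumption that teardown may also run in a process that never created a group; this is not the commit's code):

import torch.distributed as torch_distrib


def cleanup_dist_group():
    # destroy the default process group only if one was actually initialized
    if torch_distrib.is_available() and torch_distrib.is_initialized():
        torch_distrib.destroy_process_group()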
