diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index 393650ccebeec..978eaec554e2a 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -9,7 +9,7 @@ Trainer
     :exclude-members:
         run_pretrain_routine,
         _abc_impl,
-        _Trainer_set_random_port,
+        set_random_port,
         _Trainer__set_root_gpu,
         _Trainer__init_optimizers,
         _Trainer__parse_gpu_ids,
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index a4f52711f972e..71b3600ce474f 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -30,7 +30,7 @@ def teardown(self, stage: str):
         Called at the end of fit and test.
 
         Args:
-            step: either 'fit' or 'test'
+            stage: either 'fit' or 'test'
         """
 
     def on_fit_start(self):
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 68629dc8c178c..67a3c6d317e42 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -475,6 +475,11 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         model.trainer = self
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         # on world_size=0 let everyone know training is starting
         if self.is_global_zero:
             log.info('-' * 100)
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 8807fd1cfc879..145afba576bce 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -155,6 +155,11 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
         return move_data_to_device(batch, device)
 
     def single_gpu_train(self, model):
+        # call setup
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         model.cuda(self.root_gpu)
 
         # CHOOSE OPTIMIZER
@@ -171,6 +176,11 @@ def single_gpu_train(self, model):
         self.run_pretrain_routine(model)
 
     def tpu_train(self, tpu_core_idx, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         # put model on tpu
         self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
         model.to(self._device)
@@ -205,6 +215,10 @@ def tpu_train(self, tpu_core_idx, model):
             self.save_spawn_weights(model)
 
     def dp_train(self, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -246,6 +260,11 @@ def dp_train(self, model):
             model.forward = model_autocast_original_forward
 
     def horovod_train(self, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
             assert self.root_gpu == hvd.local_rank()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 36664b55498bf..4783e9ec74bfc 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -857,12 +857,6 @@ def fit(
             model.prepare_data()
             self._is_data_prepared = True
 
-        self.barrier('fit_prepare_data')
-
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
-            model.setup('fit')
-
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
@@ -897,19 +891,19 @@ def fit(
                 self.ddp_train(task, model)
 
             elif self.distributed_backend == 'cpu_ddp':
-                self._set_random_port
+                self.set_random_port()
                 self.model = model
                 mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
 
             elif self.distributed_backend == 'ddp_spawn':
-                self._set_random_port
+                self.set_random_port()
                 model.share_memory()
 
                 # spin up peers
                 mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
 
             elif self.distributed_backend == 'ddp':
-                self._set_random_port
+                self.set_random_port()
                 self.spawn_ddp_children(model)
 
         # 1 gpu or dp option triggers training using DP module
@@ -932,6 +926,9 @@ def fit(
             # track for predict
             self.model = model
 
+            # wait for all prepare data nodes to finish
+            self.barrier('setup')
+
             # train
             if self.tpu_id is not None:
                 self.tpu_train(self.tpu_id, model)
@@ -948,6 +945,11 @@ def fit(
             if self.use_amp:
                 raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
+            # call setup after the ddp process has connected
+            self.setup('fit')
+            if self.is_function_implemented('setup', model):
+                model.setup('fit')
+
             # CHOOSE OPTIMIZER
             # allow for lr schedulers as well
             self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 54b8c0271582d..ce805161ce47f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -153,6 +153,7 @@ def training_step(self, batch, batch_idx):
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
+import torch.distributed as torch_distrib
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -258,7 +259,7 @@ def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def is_function_implemented(self, *args):
+    def is_function_implemented(self, *args, **kwargs):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -701,6 +702,11 @@ def _get_optimizers_iterable(self):
     def run_training_teardown(self):
         if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
             return
+
+        # clean up dist group
+        if self.use_ddp or self.use_ddp2:
+            torch_distrib.destroy_process_group()
+
         # Train end events
         with self.profiler.profile('on_train_end'):
             # callbacks
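
Illustrative usage (not part of the patch): a minimal sketch of how a user-defined LightningModule is expected to interact with the relocated hooks after this change. `prepare_data` still runs once before the worker processes start, while `setup('fit')` now runs inside each training process (DDP, DDP-spawn, DP, Horovod, TPU, single GPU, or CPU) after that process is set up, so per-process state such as dataset splits can be built there; `teardown('fit')` pairs with it at the end of fit. The model, dataset, and sizes below are made up for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import pytorch_lightning as pl


class ExampleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def prepare_data(self):
        # called once (rank zero): download or write data to disk here
        pass

    def setup(self, stage):
        # called in every training process after it has connected; stage is 'fit' or 'test'
        data = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        self.train_set, self.val_set = random_split(data, [200, 56])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def teardown(self, stage):
        # called at the end of fit/test; per-process cleanup goes here
        pass


if __name__ == '__main__':
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(ExampleModel())

With the DDP backends, the patch additionally destroys the process group in `run_training_teardown`, so each spawned process exits with its distributed state cleaned up rather than leaving the group open.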