diff --git a/.pyrightconfig.json b/.pyrightconfig.json
index bb65bfb38141e..97000d69dd29d 100644
--- a/.pyrightconfig.json
+++ b/.pyrightconfig.json
@@ -7,6 +7,7 @@
     "pytorch_lightning/__init__.py",
     "pytorch_lightning/callbacks",
     "pytorch_lightning/core",
+    "pytorch_lightning/accelerators",
     "pytorch_lightning/loggers",
     "pytorch_lightning/logging",
     "pytorch_lightning/metrics",
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9c901a1c4ae12..8545c05acf2bb 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -138,6 +138,7 @@ exclude_patterns = [
     'api/pytorch_lightning.rst',
     'api/pl_examples.*',
+    'api/pytorch_lightning.accelerators.*',
     'api/modules.rst',
     'PULL_REQUEST_TEMPLATE.md',
diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
new file mode 100644
index 0000000000000..fbd1ab4adea87
--- /dev/null
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -0,0 +1 @@
+from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py
new file mode 100644
index 0000000000000..5ef421810d4ee
--- /dev/null
+++ b/pytorch_lightning/accelerators/gpu_accelerator.py
@@ -0,0 +1,52 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+# apex is an optional dependency; fall back to None when it is not installed
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+class GPUAccelerator(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def setup(self, model):
+        # call setup
+        if not self.trainer.testing:
+            self.trainer.setup('fit')
+            model.setup('fit')
+
+        model.cuda(self.trainer.root_gpu)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # TODO: remove with dropping NVIDIA AMP support
+        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+        if self.trainer.use_amp and not native_amp_available:
+            self._setup_nvidia_apex(model)
+
+    def _setup_nvidia_apex(self, model):
+        model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+        self.trainer.optimizers = optimizers
+        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 953c14a96d330..547f9dc87a605 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index e4749bf4defff..63c6c9b8513ab 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 Lightning supports model training on a cluster managed by SLURM in the following cases:
 
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index bf03514bc2c5a..3115a80f36b77 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 Root module for all distributed operations in Lightning.
 Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
@@ -165,28 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
             return model.transfer_batch_to_device(batch, device)
         return move_data_to_device(batch, device)
 
-    def single_gpu_train(self, model):
-        # call setup
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
-
-        model.cuda(self.root_gpu)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
-
-        # TODO: remove with dropping NVIDIA AMP support
-        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
-            # An example
-            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
-            self.optimizers = optimizers
-            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
-
-        results = self.run_pretrain_routine(model)
-        return results
-
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
         if not self.testing:
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 663871d98b419..6aaa0dc663cb8 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from abc import ABC
 from typing import List, Tuple
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 30449b4f33139..934c06f254f26 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 import os
 import warnings
@@ -37,6 +51,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
+from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -646,6 +661,7 @@ def __init__(
         # tracks internal state for debugging
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
+        self.accelerator = None
 
         # Callback system
         self.on_init_end()
@@ -1057,7 +1073,9 @@ def fit(
             results = self.horovod_train(model)
 
         elif self.single_gpu:
-            results = self.single_gpu_train(model)
+            self.accelerator = GPUAccelerator(self)
+            self.accelerator.setup(model)
+            results = self.run_pretrain_routine(model)
 
         elif self.use_tpu:  # pragma: no-cover
             rank_zero_info(f'training on {self.tpu_cores} TPU cores')
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 67374cc3e9731..d7a5332272c04 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 The lightning training loop handles everything except the actual computations of your model.
 To decide what will happen in your training loop, define the `training_step` function.
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 883eed320f80d..1cb9111b8a141 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import sys
 from abc import ABC, abstractmethod
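For reviewers, here is a minimal, self-contained sketch (not part of the diff) of the delegation pattern this PR introduces: `Trainer.fit()` no longer trains single-GPU models inline, it builds an accelerator, lets the accelerator prepare the model and optimizers, and then runs the usual pre-train routine. The `_Stub*` names below are hypothetical stand-ins for illustration only, not Lightning APIs.

```python
class _StubGPUAccelerator:
    """Plays the role of GPUAccelerator: holds a back-reference to the trainer."""

    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        # the real GPUAccelerator.setup() moves the model to the root GPU and
        # fills in trainer.optimizers / lr_schedulers / optimizer_frequencies
        self.trainer.optimizers = [f'optimizer({model})']


class _StubTrainer:
    def __init__(self):
        self.accelerator = None
        self.optimizers = []

    def fit(self, model):
        # mirrors the new single-GPU branch in Trainer.fit(): build the
        # accelerator, let it prepare the model, then run the pretrain routine
        self.accelerator = _StubGPUAccelerator(self)
        self.accelerator.setup(model)
        return self.run_pretrain_routine(model)

    def run_pretrain_routine(self, model):
        return f'results for {model} using {self.optimizers}'


print(_StubTrainer().fit('my_model'))
# -> results for my_model using ['optimizer(my_model)']
```

The design point is that `GPUAccelerator` mutates trainer state rather than returning values, which keeps `run_pretrain_routine` unchanged and lets later PRs move the remaining backends (TPU, DDP, Horovod) behind the same `setup()` interface.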