# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

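# torch_xla is only importable in TPU-enabled environments, so record its
# availability in a module-level flag instead of failing at import time.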
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


class TPUAccelerator(object):
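    """Accelerator that runs training on TPU cores via ``torch_xla``.

    Training runs either on a single, user-selected core or across all
    requested cores, with one process spawned per core.
    """
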
    def __init__(self, trainer):
        self.trainer = trainer
        self.start_method = None

    def setup(self):
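        """Verify that ``torch_xla`` is importable and pick the multiprocessing start method."""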
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('PyTorch XLA (torch_xla) is not available. It is required for TPU training.')

        # COLAB_GPU is an env var set by default in Colab environments.
        # Notebook platforms (Colab/Kaggle) use the 'fork' start method; everything else uses 'spawn'.
        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'

    def teardown(self):

        # when training completes, load the weights back into the main process
        self.__load_weights_on_main_process()

    def train(self, model):
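        """Run training on the single requested TPU core, or spawn one training process per core."""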
        self.trainer.model = model

        # train directly on the requested core, otherwise spawn one process per core
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model,),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method,
            )

    def __load_weights_on_main_process(self):
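        """Reload the weights saved by the spawned processes into the main-process model (Colab/Kaggle only)."""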
        model = self.trainer.model

        # load the weights saved by the spawned processes (Colab/Kaggle only, skipped when testing)
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.trainer.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx, model):
        """
        Entry point for each TPU process. This runs inside every spawned process
        (or directly in the main process when a single core is selected).
        """
        if not self.trainer.testing:
            self.trainer.setup('fit')
            model.setup('fit')

        # setup TPU training
        self.__setup_tpu_training(model)

        # run the pretrain routine
        self.trainer.run_pretrain_routine(model)

        # save weights at the end of training
        self.__save_end_of_training_weights(model)

    def __save_end_of_training_weights(self, model):
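        """Dump the trained weights to disk so the main process can load them back (Colab/Kaggle only)."""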

        # on Colab/Kaggle the spawned process cannot hand the weights back directly,
        # so dump them to disk for the main process to reload
        if self.trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.trainer.save_spawn_weights(model)

    def __setup_tpu_training(self, model):
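        """Move the model to its XLA device and configure ranks, optimizers and precision for this process."""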
        # use the default device for this process
        tpu_device = xm.xla_device()

        # if a specific core was requested, use its ordinal instead
        if self.trainer.tpu_id is not None:
            tpu_device = xm.xla_device(self.trainer.tpu_id)

        # track the device and move the model to it
        self.trainer._device = tpu_device
        model.to(self.trainer._device)

        # get the appropriate TPU ranks
        self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
        self.trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating the progress bar
        if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        self.trainer.global_rank = self.trainer.tpu_local_core_rank
        rank_zero_only.rank = self.trainer.global_rank

        # initialize optimizers and learning-rate schedulers (with optional frequencies)
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # 16-bit precision on TPU maps to bfloat16 through the XLA_USE_BF16 env var
        if self.trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
                 f' global rank: {self.trainer.tpu_global_core_rank}')
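

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the accelerator itself): the
# Trainer is expected to drive the accelerator roughly as below -- setup()
# picks the start method, train() runs a single core or spawns one process
# per core, and teardown() loads the weights back in the main process. The
# `trainer` and `model` names are stand-ins for the real Lightning Trainer
# and LightningModule; the actual wiring lives inside pytorch_lightning.
# ---------------------------------------------------------------------------
#
#   accelerator = TPUAccelerator(trainer)
#   accelerator.setup()
#   accelerator.train(model)
#   accelerator.teardown()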