
Commit b34217e

Refactor 2/n (#2708)
* refactor into gpu accelerator
1 parent e9ed9b7 commit b34217e

File tree: 4 files changed (+140 -64 lines)
pytorch_lightning/accelerators/__init__.py
@@ -1 +1,2 @@
 from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
+from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
pytorch_lightning/accelerators/tpu_accelerator.py
@@ -0,0 +1,133 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning import _logger as log
+
+
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
+
+class TPUAccelerator(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.start_method = None
+
+    def setup(self):
+        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
+
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException('No TPU devices found.')
+
+        # COLAB_GPU is an env var available by default in Colab environments.
+        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
+
+    def teardown(self):
+
+        # when training completes, load the weights back in main process
+        self.__load_weights_on_main_process()
+
+    def train(self, model):
+        self.trainer.model = model
+
+        # train
+        if self.trainer.tpu_id is not None:
+            self.tpu_train_in_process(self.trainer.tpu_id, model)
+        else:
+            xmp.spawn(
+                self.tpu_train_in_process,
+                args=(model,),
+                nprocs=self.trainer.tpu_cores,
+                start_method=self.start_method
+            )
+
+    def __load_weights_on_main_process(self):
+        model = self.trainer.model
+
+        # load weights if not interrupted
+        if self.trainer.on_colab_kaggle and not self.trainer.testing:
+            self.trainer.load_spawn_weights(model)
+
+        self.trainer.model = model
+
+    def tpu_train_in_process(self, tpu_core_idx, model):
+        """
+        Here we are inside each individual process
+        """
+        if not self.trainer.testing:
+            self.trainer.setup('fit')
+            model.setup('fit')
+
+        # setup TPU training
+        self.__setup_tpu_training(model)
+
+        # Run the pretrain routine
+        self.trainer.run_pretrain_routine(model)
+
+        # save weights at the end of training
+        self.__save_end_of_training_weights(model)
+
+    def __save_end_of_training_weights(self, model):
+
+        # when training ends on these platforms dump weights to get out of the main process
+        if self.trainer.on_colab_kaggle:
+            rank_zero_warn('cleaning up... please do not interrupt')
+            self.trainer.save_spawn_weights(model)
+
+    def __setup_tpu_training(self, model):
+        # use the default device from the process
+        tpu_device = xm.xla_device()
+
+        # if given an ordinal device, use this as the device
+        if self.trainer.tpu_id is not None:
+            tpu_device = xm.xla_device(self.trainer.tpu_id)
+
+        # track the device and move model to it
+        self.trainer._device = tpu_device
+        model.to(self.trainer._device)
+
+        # get the appropriate tpu ranks
+        self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
+        self.trainer.tpu_global_core_rank = xm.get_ordinal()
+
+        # avoid duplicating progress bar
+        if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        self.trainer.global_rank = self.trainer.tpu_local_core_rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # init 16 bit for TPU
+        if self.trainer.precision == 16:
+            os.environ['XLA_USE_BF16'] = str(1)
+
+        log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
+                 f' global rank: {self.trainer.tpu_global_core_rank}')
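Note on the spawn path in train() above: torch_xla's xmp.spawn starts one Python process per TPU core and calls the target function with the process index as its first argument, which is why tpu_train_in_process receives tpu_core_idx before model. A minimal standalone sketch of that behaviour (assuming torch_xla is installed and a TPU is attached; _train_fn and nprocs=8 are illustrative names, not part of this commit):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _train_fn(index, flags):
    # each spawned process picks up its own XLA device and ordinal
    device = xm.xla_device()
    print(f'process {index}: device={device}, ordinal={xm.get_ordinal()}')

if __name__ == '__main__':
    # 'fork' is used on Colab/Kaggle, 'spawn' elsewhere (see setup() above)
    xmp.spawn(_train_fn, args=({},), nprocs=8, start_method='spawn')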

pytorch_lightning/trainer/distrib_parts.py

-40
@@ -179,46 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
             return model.transfer_batch_to_device(batch, device)
         return move_data_to_device(batch, device)
 
-    def tpu_train(self, tpu_core_idx, model):
-        # call setup after the ddp process has connected
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
-
-        # put model on tpu
-        self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
-        model.to(self._device)
-
-        # get the appropriate tpu ranks
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
-
-        # avoid duplicating progress bar
-        if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
-            self.progress_bar_callback.disable()
-
-        self.global_rank = self.tpu_local_core_rank
-        rank_zero_only.rank = self.global_rank
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
-
-        # init 16 bit for TPU
-        if self.precision == 16:
-            os.environ['XLA_USE_BF16'] = str(1)
-
-        log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
-                 f' global rank: {self.tpu_global_core_rank}')
-
-        # continue training routine
-        self.run_pretrain_routine(model)
-
-        # when training ends on these platforms dump weights to get out of the main process
-        if self.on_colab_kaggle:
-            rank_zero_warn('cleaning up... please do not interrupt')
-            self.save_spawn_weights(model)
-
     def dp_train(self, model):
         # call setup after the ddp process has connected
         if not self.testing:

pytorch_lightning/trainer/trainer.py

+6 -24
@@ -51,7 +51,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
-from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
+from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -1077,29 +1077,11 @@ def fit(
             self.accelerator.setup(model)
             results = self.run_pretrain_routine(model)
 
-        elif self.use_tpu:  # pragma: no-cover
-            rank_zero_info(f'training on {self.tpu_cores} TPU cores')
-
-            if not XLA_AVAILABLE:
-                raise MisconfigurationException('No TPU devices found.')
-
-            # COLAB_GPU is an env var available by default in Colab environments.
-            start_method = 'fork' if self.on_colab_kaggle else 'spawn'
-
-            # track for predict
-            self.model = model
-
-            # train
-            if self.tpu_id is not None:
-                self.tpu_train(self.tpu_id, model)
-            else:
-                xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)
-
-            # load weights if not interrupted
-            if self.on_colab_kaggle and not self.testing:
-                self.load_spawn_weights(model)
-
-            self.model = model
+        elif self.use_tpu:
+            self.accelerator = TPUAccelerator(self)
+            self.accelerator.setup()
+            self.accelerator.train(model)
+            self.accelerator.teardown()
 
         # ON CPU
         else:
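The user-facing API is unchanged by this refactor: a Trainer created with tpu_cores still routes fit() through the TPU branch above, which now delegates to TPUAccelerator's setup()/train()/teardown() sequence. A rough sketch of that call path (MyModel is a placeholder LightningModule, not part of this commit; API names as of this version of Lightning):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    ...  # training_step, configure_optimizers, etc. as usual

trainer = pl.Trainer(tpu_cores=8)  # selects the use_tpu branch in fit()
trainer.fit(MyModel())             # -> TPUAccelerator(trainer).setup(), .train(model), .teardown()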
