
Commit 8211256

awaelchli and Borda authored
data transfer model hook (+ refactor) (#1756)
* refactor and added hook variant a variant b add test revert rename add changelog docs
* resolve merge duplication
* overridden typo
* fix test
* tpu id
* raise if TPU not available
* re-use apply_to_collection function for parsing collections
* comment
* make utility function available to user
* documentation
* move changelog entry to top
* fix tpu transfer call
* fix call
* remove hardcoded string
* improve test
* call model hook by default
* Apply suggestions from code review
* rename utility function

Co-authored-by: Jirka Borovec <[email protected]>
1 parent ade3f36 commit 8211256

File tree

9 files changed: +158 -56 lines

CHANGELOG.md (+2)

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
 
+- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)).
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
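The entry above announces the user-facing hook added by this commit. Below is a minimal sketch of how a user might override it; the `LitModel` class and its omitted training logic are hypothetical, while `CustomBatch` mirrors the wrapper used in the docstring and test further down:

```python
import torch
import pytorch_lightning as pl


class CustomBatch:
    """Illustrative wrapper that a custom collate_fn might return."""

    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets


class LitModel(pl.LightningModule):
    # training_step / configure_optimizers etc. omitted for brevity

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move the tensors held inside the custom structure
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
        else:
            # tensors, lists, dicts and tuples are handled by the default hook
            batch = super().transfer_batch_to_device(batch, device)
        return batch
```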

pytorch_lightning/__init__.py (+2 -2)

@@ -60,8 +60,8 @@
     'Trainer',
     'LightningModule',
     'Callback',
-    'data_loader'
-    'seed_everything'
+    'data_loader',
+    'seed_everything',
 ]
 
 # necessary for regular bolts imports. Skip exception since bolts is not always installed
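The two removed lines were missing trailing commas, so Python's implicit string concatenation silently fused them into a single `__all__` entry. A quick standalone illustration of that pitfall (not code from the repo):

```python
# without commas, adjacent string literals are concatenated at parse time
__all__ = [
    'data_loader'
    'seed_everything'
]
assert __all__ == ['data_loaderseed_everything']  # one fused name, not two
```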

pytorch_lightning/core/hooks.py (+47)

@@ -3,6 +3,8 @@
 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities import move_data_to_device
+
 
 try:
     from apex import amp
@@ -153,3 +155,48 @@ def backward(self, use_amp, loss, optimizer):
                 scaled_loss.backward()
         else:
             loss.backward()
+
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
+            batch and determines the target devices.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+        """
+        return move_data_to_device(batch, device)
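The default implementation above simply forwards to `move_data_to_device`, so the collection types listed in the docstring need no override at all. A minimal sketch of that default behavior (shapes, keys, and the CPU fallback are arbitrary choices for illustration):

```python
import torch
from pytorch_lightning.utilities import move_data_to_device

# an arbitrarily nested batch built from the supported types: dict, tuple, list, Tensor
batch = {
    'images': torch.randn(4, 3, 32, 32),
    'extras': (torch.zeros(4), [torch.ones(4, 1)]),
}

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
moved = move_data_to_device(batch, device)

# the structure is preserved; only the tensors inside were moved
assert moved['images'].device == device
assert moved['extras'][1][0].device == device
```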

pytorch_lightning/trainer/distrib_parts.py (+45 -52)

@@ -18,6 +18,7 @@
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
+from pytorch_lightning.utilities import move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
@@ -99,58 +100,50 @@ def copy_trainer_model_properties(self, model):
         m.tpu_local_core_rank = self.tpu_local_core_rank
         m.tpu_global_core_rank = self.tpu_global_core_rank
 
-    def transfer_batch_to_tpu(self, batch):
-        return self.__transfer_data_to_device(batch, device='tpu')
-
-    def transfer_batch_to_gpu(self, batch, gpu_id):
-        return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)
-
-    def __transfer_data_to_device(self, batch, device, gpu_id=None):
-        if device == 'tpu' and XLA_AVAILABLE:
-            # base case: object can be directly moved using `to`
-            if callable(getattr(batch, 'to', None)):
-                xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
-                return batch.to(xla_device)
-
-        if device == 'gpu':
-            # base case: object can be directly moved using `cuda` or `to`
-            if callable(getattr(batch, 'cuda', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.cuda(gpu_id, non_blocking=True)
-
-            if callable(getattr(batch, 'to', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
-
-        # when list
-        if isinstance(batch, list):
-            for i, x in enumerate(batch):
-                batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-            return batch
-
-        # when tuple
-        if isinstance(batch, tuple):
-            # when namedtuple
-            if hasattr(batch, '_fields'):
-                elem_type = type(batch)
-                return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
-            else:
-                batch = list(batch)
-                for i, x in enumerate(batch):
-                    batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-                return tuple(batch)
-
-        # when dict
-        if isinstance(batch, dict):
-            for k, v in batch.items():
-                batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
-
-            return batch
-
-        # nothing matches, return the value as is without transform
-        return batch
+    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
+        """
+        Transfers the data to the TPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
+
+        Return:
+            the tensor on the TPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested to transfer batch to TPU but XLA is not available.'
+                ' Are you sure this machine has TPUs?'
+            )
+        device = xm.xla_device(tpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
+        """
+        Transfers the data to the GPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.
+
+        Return:
+            the tensor on the GPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        device = torch.device('cuda', gpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
+        model = self.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)
 
     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
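The refactor above funnels both the GPU and TPU paths through a single private helper that prefers the LightningModule hook whenever a model is attached and only then falls back to the generic utility. The snippet below is an illustrative stand-in for that delegation pattern, not the Trainer's actual code:

```python
import torch
from pytorch_lightning.utilities import move_data_to_device


def transfer_batch(batch, device, model=None):
    """Illustrative delegation: prefer the model hook, else use the generic utility."""
    if model is not None:
        return model.transfer_batch_to_device(batch, device)
    return move_data_to_device(batch, device)


# with no model attached, plain tensors and collections are still handled
cpu_batch = transfer_batch([torch.zeros(2, 3)], torch.device('cpu'))
```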

pytorch_lightning/trainer/evaluation_loop.py (+1 -1)

@@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
 
         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
 
         # CPU, TPU or gpu step

pytorch_lightning/trainer/training_loop.py (+1 -1)

@@ -753,7 +753,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
 
         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
         output = self.model.training_step(*args)
pytorch_lightning/utilities/__init__.py (+1)

@@ -1,3 +1,4 @@
 """General utilities"""
 
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import move_data_to_device

pytorch_lightning/utilities/apply_func.py (+23)

@@ -1,6 +1,8 @@
 from collections import Mapping, Sequence
 from typing import Any, Callable, Union
 
+import torch
+
 
 def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
     """
@@ -34,3 +36,24 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
 
     # data is neither of dtype, nor a collection
     return data
+
+
+def move_data_to_device(batch: Any, device: torch.device):
+    """
+    Transfers a collection of tensors to the given device.
+
+    Args:
+        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
+            for a list of supported collection types.
+        device: The device to which tensors should be moved
+
+    Return:
+        the same collection but with all contained tensors residing on the new device.
+
+    See Also:
+        - :meth:`torch.Tensor.to`
+        - :class:`torch.device`
+    """
+    def to(tensor):
+        return tensor.to(device, non_blocking=True)
+    return apply_to_collection(batch, dtype=torch.Tensor, function=to)
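`move_data_to_device` is a thin wrapper over `apply_to_collection`, which can apply any per-tensor function across a nested structure. A small usage sketch (the halving function is an arbitrary example, not something this commit adds):

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {'x': torch.randn(2, 3), 'lengths': [torch.tensor([3, 2])]}

# apply a function to every Tensor in the (possibly nested) collection,
# leaving the surrounding structure untouched
halved = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t.float() / 2)

assert halved['x'].shape == (2, 3)
assert isinstance(halved['lengths'], list)
```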

tests/models/test_hooks.py (+36)

@@ -1,3 +1,5 @@
+from unittest.mock import MagicMock
+
 import pytest
 import torch
 
@@ -68,3 +70,37 @@ def training_epoch_end(self, outputs):
     # metrics are kept after each epoch
     for i in range(num_epochs):
         assert metrics[f'epoch_metric_{i}'] == i
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_transfer_batch_hook():
+
+    class CustomBatch:
+
+        def __init__(self, data):
+            self.samples = data[0]
+            self.targets = data[1]
+
+    class CurrentTestModel(EvalModelTemplate):
+
+        hook_called = False
+
+        def transfer_batch_to_device(self, data, device):
+            self.hook_called = True
+            if isinstance(data, CustomBatch):
+                data.samples = data.samples.to(device)
+                data.targets = data.targets.to(device)
+            else:
+                data = super().transfer_batch_to_device(data, device)
+            return data
+
+    model = CurrentTestModel()
+    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))
+
+    trainer = Trainer()
+    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
+    trainer.get_model = MagicMock(return_value=model)
+    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    expected = torch.device('cuda', 0)
+    assert model.hook_called
+    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
