
Commit 6efccda

awaelchli authored and atee committed
Do not pass non_blocking=True if it does not support this argument (Lightning-AI#2910)
* add docs
* non blocking only on tensor
* changelog
* add test case
* add test comment
* update changelog
1 parent e9ad6db commit 6efccda

4 files changed: +33 -3 lines changed
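For context on the failure being fixed: the batch-transfer utility in pytorch_lightning/utilities/apply_func.py used to pass non_blocking=True to every object's to() method, which raises a TypeError for batch objects whose to() only accepts a device. The snippet below is an illustrative sketch only; the BatchWrapper class is hypothetical and not part of this commit.

import torch

class BatchWrapper:
    """Hypothetical user-defined batch object whose .to() takes only a device."""

    def __init__(self):
        self.data = torch.zeros(2, 3)

    def to(self, device):
        self.data = self.data.to(device)
        return self

batch = BatchWrapper()
batch.to(torch.device("cpu"))                       # works
# batch.to(torch.device("cpu"), non_blocking=True)  # TypeError: unexpected keyword argument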

CHANGELOG.md

+2
@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828))
 
+- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added

pytorch_lightning/core/hooks.py

+7 -2

@@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
         Note:
             This hook should only transfer the data and not modify it, nor should it move the data to
             any other device than the one passed in as argument (unless you know what you are doing).
-            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
-            batch and determines the target devices.
+
+        Note:
+            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
+            for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
 
         See Also:
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
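To show what the docstring above is describing, here is a rough sketch of overriding the hook for a custom batch type on single-GPU training. The CustomBatch and LitModel names are invented for this example and are not part of the commit; the fallback to super() assumes the default hook moves standard collections to the device.

import torch
from pytorch_lightning import LightningModule

class CustomBatch:
    """Hypothetical batch wrapper holding two tensors."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

class LitModel(LightningModule):

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move only the tensors the wrapper holds
            batch.inputs = batch.inputs.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # anything else falls back to the default transfer logic
        return super().transfer_batch_to_device(batch, device)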

pytorch_lightning/utilities/apply_func.py

+2 -1

@@ -104,6 +104,7 @@ def batch_to(data):
                 setattr(device_data, field, device_field)
             return device_data
 
-        return data.to(device, non_blocking=True)
+        kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
+        return data.to(device, **kwargs)
 
     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
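The fix is a guarded-kwargs pattern: the non_blocking keyword is only passed when the object is a torch.Tensor, whose to() is known to accept it; any other object is called with the device alone. A self-contained sketch of the same idea (the to_device helper is illustrative, not the library function):

import torch

def to_device(data, device):
    # only tensors get non_blocking=True; other objects may define a .to()
    # that rejects unknown keyword arguments
    kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
    return data.to(device, **kwargs)

class PlainBatch:
    def to(self, device):  # accepts no keyword arguments
        return self

to_device(torch.zeros(2, 3), torch.device("cpu"))  # keyword passed to the tensor
to_device(PlainBatch(), torch.device("cpu"))       # keyword omitted, no TypeError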

tests/models/test_gpu.py

+22
@@ -1,4 +1,5 @@
 from collections import namedtuple
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -384,3 +385,24 @@ def to(self, *args, **kwargs):
 
     assert batch.text.type() == 'torch.cuda.LongTensor'
     assert batch.label.type() == 'torch.cuda.LongTensor'
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_non_blocking():
+    """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """
+    trainer = Trainer()
+
+    batch = torch.zeros(2, 3)
+    with patch.object(batch, 'to', wraps=batch.to) as mocked:
+        trainer.transfer_batch_to_gpu(batch, 0)
+        mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True)
+
+    class BatchObject(object):
+
+        def to(self, *args, **kwargs):
+            pass
+
+    batch = BatchObject()
+    with patch.object(batch, 'to', wraps=batch.to) as mocked:
+        trainer.transfer_batch_to_gpu(batch, 0)
+        mocked.assert_called_with(torch.device('cuda', 0))
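The test leans on unittest.mock's wraps= argument: the patched to attribute still calls the real method, while the mock records the call so the exact arguments can be asserted. A minimal CPU-only sketch of the same pattern, standalone and not part of the test suite:

import torch
from unittest.mock import patch

tensor = torch.zeros(2, 3)
with patch.object(tensor, 'to', wraps=tensor.to) as mocked:
    tensor.to(torch.device('cpu'), non_blocking=True)  # the real .to() still runs
    mocked.assert_called_with(torch.device('cpu'), non_blocking=True)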
