Skip to content

Commit e82d9cd

Browse files
authored
Support torchtext on a single GPU (#2379)
* Handle torchtext.data.Batch on GPU * Update CHANGELOG.md * Apply code review requests * Correct the docs * Change requirements
1 parent 73a78a1 commit e82d9cd

File tree

6 files changed

+49
-4
lines changed

6 files changed

+49
-4
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
- Added TorchText support for moving data to GPU ([#2379](https://github.com/PyTorchLightning/pytorch-lightning/pull/2379))
12+
1113
### Changed
1214

1315
- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289))

pytorch_lightning/core/hooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
208208
- :class:`list`
209209
- :class:`dict`
210210
- :class:`tuple`
211-
- ``torchtext.data.Batch`` (COMING SOON)
211+
- :class:`torchtext.data.batch.Batch`
212212
213213
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
214214

pytorch_lightning/utilities/apply_func.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Any, Callable, Union
44

55
import torch
6+
from torchtext.data import Batch
7+
from copy import copy
68

79

810
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
@@ -84,6 +86,16 @@ def move_data_to_device(batch: Any, device: torch.device):
8486
- :meth:`torch.Tensor.to`
8587
- :class:`torch.device`
8688
"""
87-
def to(data):
89+
90+
def batch_to(data):
91+
if isinstance(data, Batch):
92+
# Shallow copy because each Batch has a reference to Dataset which contains all examples
93+
device_data = copy(data)
94+
for field in data.fields:
95+
# Batch contains output of Field.process(...) which is tensor hence .to(...) exists
96+
device_field = getattr(data, field).to(device, non_blocking=True)
97+
setattr(device_data, field, device_field)
98+
return device_data
99+
88100
return data.to(device, non_blocking=True)
89-
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
101+
return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)

requirements/base.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tensorboard>=1.14
77
future>=0.17.1 # required for builtins in setup.py
88
# pyyaml>=3.13
99
PyYAML>=5.1 # OmegaConf requirement
10+
torchtext>=0.3.1

requirements/extra.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ matplotlib>=3.1.1
1010
horovod>=0.19.1
1111
omegaconf>=2.0.0
1212
# scipy>=0.13.3
13-
scikit-learn>=0.20.0
13+
scikit-learn>=0.20.0

tests/models/test_gpu.py

+30
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
1111
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1212
from tests.base import EvalModelTemplate
13+
from torchtext.data import Batch, Dataset, Example, Field, LabelField
1314

1415
PRETEND_N_OF_GPUS = 16
1516

@@ -301,3 +302,32 @@ def to(self, *args, **kwargs):
301302

302303
batch = trainer.transfer_batch_to_gpu(CustomBatchType())
303304
assert batch.a.type() == 'torch.cuda.FloatTensor'
305+
306+
# torchtext.data.Batch
307+
samples = [
308+
{'text': 'PyTorch Lightning is awesome!', 'label': 0},
309+
{'text': 'Please make it work with torchtext', 'label': 1}
310+
]
311+
312+
text_field = Field()
313+
label_field = LabelField()
314+
fields = {
315+
'text': ('text', text_field),
316+
'label': ('label', label_field)
317+
}
318+
319+
examples = [Example.fromdict(sample, fields) for sample in samples]
320+
dataset = Dataset(
321+
examples=examples,
322+
fields=fields.values()
323+
)
324+
325+
# Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
326+
text_field.build_vocab(dataset)
327+
label_field.build_vocab(dataset)
328+
329+
batch = Batch(data=examples, dataset=dataset)
330+
batch = trainer.transfer_batch_to_gpu(batch, 0)
331+
332+
assert batch.text.type() == 'torch.cuda.LongTensor'
333+
assert batch.label.type() == 'torch.cuda.LongTensor'

0 commit comments

Comments
 (0)