from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
+from torchtext.data import Batch, Dataset, Example, Field, LabelField

PRETEND_N_OF_GPUS = 16

@@ -301,3 +302,32 @@ def to(self, *args, **kwargs):

    batch = trainer.transfer_batch_to_gpu(CustomBatchType())
    assert batch.a.type() == 'torch.cuda.FloatTensor'
+
+    # torchtext.data.Batch
+    samples = [
+        {'text': 'PyTorch Lightning is awesome!', 'label': 0},
+        {'text': 'Please make it work with torchtext', 'label': 1}
+    ]
+
+    text_field = Field()
+    label_field = LabelField()
+    fields = {
+        'text': ('text', text_field),
+        'label': ('label', label_field)
+    }
+
+    examples = [Example.fromdict(sample, fields) for sample in samples]
+    dataset = Dataset(
+        examples=examples,
+        fields=fields.values()
+    )
+
+    # Batch runs field.process(), which numericalizes the tokens, but that requires the vocabulary to be built first
+    text_field.build_vocab(dataset)
+    label_field.build_vocab(dataset)
+
+    batch = Batch(data=examples, dataset=dataset)
+    batch = trainer.transfer_batch_to_gpu(batch, 0)
+
+    assert batch.text.type() == 'torch.cuda.LongTensor'
+    assert batch.label.type() == 'torch.cuda.LongTensor'