351 | 351 |     LightningDistributedDataParallel,
352 | 352 |     LightningDataParallel,
353 | 353 | )
    | 354 | +from pytorch_lightning.utilities.apply_func import apply_to_collection
354 | 355 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
355 | 356 | from pytorch_lightning.utilities.distributed import rank_zero_only
356 | 357 |

@@ -446,41 +447,13 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
446 | 447 |         return self.__transfer_data_to_device(batch, device)
447 | 448 |
448 | 449 |     def __transfer_data_to_device(self, batch: Any, device: torch.device):
449 |     | -
450 | 450 |         if self.is_overridden('transfer_batch_to_device'):
451 | 451 |             return self.get_model().transfer_batch_to_device(batch, device)
452 | 452 |
453 |     | -        # base case: object can be directly moved using `to`
454 |     | -        if callable(getattr(batch, 'to', None)):
455 |     | -            return batch.to(device, non_blocking=True)
456 |     | -
457 |     | -        # when list
458 |     | -        if isinstance(batch, list):
459 |     | -            for i, x in enumerate(batch):
460 |     | -                batch[i] = self.__transfer_data_to_device(x, device)
461 |     | -            return batch
462 |     | -
463 |     | -        # when tuple
464 |     | -        if isinstance(batch, tuple):
465 |     | -            # when namedtuple
466 |     | -            if hasattr(batch, '_fields'):
467 |     | -                elem_type = type(batch)
468 |     | -                return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch))
469 |     | -            else:
470 |     | -                batch = list(batch)
471 |     | -                for i, x in enumerate(batch):
472 |     | -                    batch[i] = self.__transfer_data_to_device(x, device)
473 |     | -                return tuple(batch)
474 |     | -
475 |     | -        # when dict
476 |     | -        if isinstance(batch, dict):
477 |     | -            for k, v in batch.items():
478 |     | -                batch[k] = self.__transfer_data_to_device(v, device)
479 |     | -
480 |     | -            return batch
    | 453 | +        def to(tensor):
    | 454 | +            return tensor.to(device, non_blocking=True)
481 | 455 |
482 |     | -        # nothing matches, return the value as is without transform
483 |     | -        return batch
    | 456 | +        return apply_to_collection(batch, dtype=torch.Tensor, function=to)
484 | 457 |
485 | 458 |     def single_gpu_train(self, model):
486 | 459 |         model.cuda(self.root_gpu)
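The early return kept at new lines 450-451 means a user override still takes precedence: when the LightningModule implements `transfer_batch_to_device(batch, device)` (signature as called in this diff), the trainer delegates to it and `apply_to_collection` is never reached. A minimal sketch of such an override, assuming a hypothetical `CustomBatch` wrapper that a tensor-only traversal would not know how to move:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection


class CustomBatch:
    """Hypothetical wrapper holding tensors as attributes rather than in a list/dict."""

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        self.inputs = inputs
        self.targets = targets


class MyModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move the wrapped tensors ourselves
            batch.inputs = batch.inputs.to(device, non_blocking=True)
            batch.targets = batch.targets.to(device, non_blocking=True)
            return batch
        # otherwise do the same tensor-wise move the trainer performs
        return apply_to_collection(batch, dtype=torch.Tensor,
                                   function=lambda t: t.to(device, non_blocking=True))
```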