Skip to content

Commit c884f68

Browse files
committed
re-use the apply_to_collection function for parsing collections
1 parent be3fa7e commit c884f68

File tree

1 file changed

+4
-31
lines changed

1 file changed

+4
-31
lines changed

pytorch_lightning/trainer/distrib_parts.py

+4-31
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@
351351
LightningDistributedDataParallel,
352352
LightningDataParallel,
353353
)
354+
from pytorch_lightning.utilities.apply_func import apply_to_collection
354355
from pytorch_lightning.utilities.exceptions import MisconfigurationException
355356
from pytorch_lightning.utilities.distributed import rank_zero_only
356357

@@ -446,41 +447,13 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
446447
return self.__transfer_data_to_device(batch, device)
447448

448449
def __transfer_data_to_device(self, batch: Any, device: torch.device):
449-
450450
if self.is_overridden('transfer_batch_to_device'):
451451
return self.get_model().transfer_batch_to_device(batch, device)
452452

453-
# base case: object can be directly moved using `to`
454-
if callable(getattr(batch, 'to', None)):
455-
return batch.to(device, non_blocking=True)
456-
457-
# when list
458-
if isinstance(batch, list):
459-
for i, x in enumerate(batch):
460-
batch[i] = self.__transfer_data_to_device(x, device)
461-
return batch
462-
463-
# when tuple
464-
if isinstance(batch, tuple):
465-
# when namedtuple
466-
if hasattr(batch, '_fields'):
467-
elem_type = type(batch)
468-
return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch))
469-
else:
470-
batch = list(batch)
471-
for i, x in enumerate(batch):
472-
batch[i] = self.__transfer_data_to_device(x, device)
473-
return tuple(batch)
474-
475-
# when dict
476-
if isinstance(batch, dict):
477-
for k, v in batch.items():
478-
batch[k] = self.__transfer_data_to_device(v, device)
479-
480-
return batch
453+
def to(tensor):
454+
return tensor.to(device, non_blocking=True)
481455

482-
# nothing matches, return the value as is without transform
483-
return batch
456+
return apply_to_collection(batch, dtype=torch.Tensor, function=to)
484457

485458
def single_gpu_train(self, model):
486459
model.cuda(self.root_gpu)

0 commit comments

Comments (0)