|
1 | 1 | import pytest
|
| 2 | +import torch |
2 | 3 |
|
3 | 4 | import tests.base.utils as tutils
|
4 | 5 | from pytorch_lightning import Trainer
|
@@ -482,3 +483,41 @@ class CurrentTestModel(
|
482 | 483 | test_percent_check=0.5
|
483 | 484 | )
|
484 | 485 | trainer.fit(model)
|
| 486 | + |
| 487 | + |
| 488 | +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') |
| 489 | +def test_dataloader_reinit_for_subclass(): |
| 490 | + |
| 491 | + class CustomDataLoader(torch.utils.data.DataLoader): |
| 492 | + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, |
| 493 | + batch_sampler=None, num_workers=0, collate_fn=None, |
| 494 | + pin_memory=False, drop_last=False, timeout=0, |
| 495 | + worker_init_fn=None, dummy_kwarg=None): |
| 496 | + super().__init__(dataset, |
| 497 | + batch_size, |
| 498 | + shuffle, |
| 499 | + sampler, |
| 500 | + batch_sampler, |
| 501 | + num_workers, |
| 502 | + collate_fn, |
| 503 | + pin_memory, |
| 504 | + drop_last, |
| 505 | + timeout, |
| 506 | + worker_init_fn) |
| 507 | + |
| 508 | + self.dummy_kwarg = dummy_kwarg |
| 509 | + |
| 510 | + trainer = Trainer(gpus=[0, 1], |
| 511 | + num_nodes=1, |
| 512 | + distributed_backend='ddp') |
| 513 | + |
| 514 | + class CustomDummyObj: |
| 515 | + sampler = None |
| 516 | + |
| 517 | + result = trainer.auto_add_sampler(CustomDummyObj(), train=True) |
| 518 | + assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader" |
| 519 | + |
| 520 | + result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True) |
| 521 | + assert isinstance(result, torch.utils.data.DataLoader) |
| 522 | + assert isinstance(result, CustomDataLoader) |
| 523 | + assert hasattr(result, 'dummy_kwarg') |
0 commit comments