
Commit a2a50e7

Authored by williamFalcon and shubhamagarwal92, committed by akarnachev
Shubhamagarwal92 master (Lightning-AI#1349)
* SA: for Lightning-AI#958: set torch cuda device when finding root
* SA: for Lightning-AI#958: removing root gpu hack in trainer/evaluation_loop
* SA: setting torch cuda device
* comment line too long
* check if root gpu exists or available
* Incorporating suggestions on Lightning-AI#1094
* since root gpu returns none instead of -1 for cpu
* undo changes
* fixed dp memory thing

Co-authored-by: Shubham Agarwal <[email protected]>
1 parent: 5e63120

File tree

2 files changed (+4, -0)


pytorch_lightning/trainer/distrib_parts.py (+3)

@@ -526,6 +526,9 @@ def dp_train(self, model):
         if isinstance(device_ids, int):
             device_ids = list(range(device_ids))
 
+        # set dp device
+        torch.cuda.set_device(self.root_gpu)
+
         model = LightningDataParallel(model, device_ids=device_ids)
 
         self.run_pretrain_routine(model)
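For context (not part of the commit): torch.cuda.set_device changes the process-wide default CUDA device, so tensors created afterwards with a bare "cuda" device land on the root GPU rather than GPU 0. A minimal sketch of that behavior, where the index 1 and the tensor shape are illustrative assumptions:

    import torch

    # Sketch: after set_device, bare "cuda" allocations go to the chosen GPU.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        root_gpu = 1  # hypothetical root GPU index, analogous to self.root_gpu
        torch.cuda.set_device(root_gpu)
        x = torch.zeros(4, device="cuda")  # lands on cuda:1, not cuda:0
        assert x.device == torch.device("cuda", root_gpu)

Pinning the device before constructing LightningDataParallel is plausibly what the "fixed dp memory thing" bullet in the commit message refers to.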

pytorch_lightning/trainer/trainer.py (+1)

@@ -389,6 +389,7 @@ def __init__(
         self.gpus = gpus
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        self.root_device = torch.device("cpu")
 
         # tpu state flags
         self.use_tpu = False
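The second change initializes self.root_device to CPU. Per the commit message, the root-gpu helper returns None (not -1) on CPU-only runs, so a CPU default is the safe starting point. A hedged sketch of the resulting pattern; the promotion to a GPU device is an assumption, not part of this diff:

    import torch

    # Sketch: default to CPU, promote to the root GPU only when one exists.
    root_gpu = None  # None (not -1) signals a CPU-only run
    root_device = torch.device("cpu") if root_gpu is None else torch.device("cuda", root_gpu)
    print(root_device)  # -> cpu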
