Skip to content

Commit 9445c80

Browse files
authored
set device to root gpu (#3042)
1 parent 9f6be96 commit 9445c80

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytorch_lightning/accelerators/gpu_backend.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import torch
1516
from pytorch_lightning.core import LightningModule
1617
from pytorch_lightning.utilities import AMPType
1718

@@ -32,6 +33,7 @@ def setup(self, model):
3233
# call setup
3334
self.trainer.call_setup_hook(model)
3435

36+
torch.cuda.set_device(self.trainer.root_gpu)
3537
model.cuda(self.trainer.root_gpu)
3638

3739
# CHOOSE OPTIMIZER

0 commit comments

Comments
 (0)