We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9f6be96 commit 9445c80Copy full SHA for 9445c80
pytorch_lightning/accelerators/gpu_backend.py
@@ -12,6 +12,7 @@
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
+import torch
16
from pytorch_lightning.core import LightningModule
17
from pytorch_lightning.utilities import AMPType
18
@@ -32,6 +33,7 @@ def setup(self, model):
32
33
# call setup
34
self.trainer.call_setup_hook(model)
35
36
+ torch.cuda.set_device(self.trainer.root_gpu)
37
model.cuda(self.trainer.root_gpu)
38
39
# CHOOSE OPTIMIZER
0 commit comments