Skip to content

Commit 1d5607d

Browse files
Fixed failing GPU test.
1 parent 4b8dbb4 commit 1d5607d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ml-agents/mlagents/torch_utils/torch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None:
5353

5454
if _device.type == "cuda":
5555
torch.set_default_device(_device.type)
56-
torch.set_default_dtype(torch.cuda.FloatTensor)
56+
torch.set_default_dtype(torch.float32)
5757
else:
5858
torch.set_default_dtype(torch.float32)
5959
logger.debug(f"default Torch device: {_device}")

ml-agents/mlagents/trainers/tests/test_torch_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
"device_str, expected_type, expected_index, expected_tensor_type",
1212
[
1313
("cpu", "cpu", None, torch.float32),
14-
("cuda", "cuda", None, torch.cuda.FloatTensor),
15-
("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
14+
("cuda", "cuda", None, torch.float32),
15+
("cuda:42", "cuda", 42, torch.float32),
1616
("opengl", "opengl", None, torch.float32),
1717
],
1818
)

0 commit comments

Comments
 (0)