Skip to content

Commit 3e56db2

Browse files
authored
Add MPS device (#1197)
* Add MPS device * Fix device setting
1 parent de85c09 commit 3e56db2

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

mnist_rnn/main.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ def main():
9191
help='learning rate (default: 0.1)')
9292
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
9393
help='learning rate step gamma (default: 0.7)')
94-
parser.add_argument('--no-cuda', action='store_true', default=False,
95-
help='disables CUDA training')
94+
parser.add_argument('--cuda', action='store_true', default=False,
95+
help='enables CUDA training')
96+
parser.add_argument('--mps', action="store_true", default=False,
97+
help="enables MPS training")
9698
parser.add_argument('--dry-run', action='store_true', default=False,
9799
help='quickly check a single pass')
98100
parser.add_argument('--seed', type=int, default=1, metavar='S',
@@ -102,13 +104,19 @@ def main():
102104
parser.add_argument('--save-model', action='store_true', default=False,
103105
help='for Saving the current Model')
104106
args = parser.parse_args()
105-
use_cuda = not args.no_cuda and torch.cuda.is_available()
106107

107-
torch.manual_seed(args.seed)
108+
if args.cuda and not args.mps:
109+
device = "cuda"
110+
elif args.mps and not args.cuda:
111+
device = "mps"
112+
else:
113+
device = "cpu"
114+
115+
device = torch.device(device)
108116

109-
device = torch.device("cuda" if use_cuda else "cpu")
117+
torch.manual_seed(args.seed)
110118

111-
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
119+
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
112120
train_loader = torch.utils.data.DataLoader(
113121
datasets.MNIST('../data', train=True, download=True,
114122
transform=transforms.Compose([

0 commit comments

Comments
 (0)