@@ -91,8 +91,10 @@ def main():
91
91
help = 'learning rate (default: 0.1)' )
92
92
parser .add_argument ('--gamma' , type = float , default = 0.7 , metavar = 'M' ,
93
93
help = 'learning rate step gamma (default: 0.7)' )
94
- parser .add_argument ('--no-cuda' , action = 'store_true' , default = False ,
95
- help = 'disables CUDA training' )
94
+ parser .add_argument ('--cuda' , action = 'store_true' , default = False ,
95
+ help = 'enables CUDA training' )
96
+ parser .add_argument ('--mps' , action = "store_true" , default = False ,
97
+ help = "enables MPS training" )
96
98
parser .add_argument ('--dry-run' , action = 'store_true' , default = False ,
97
99
help = 'quickly check a single pass' )
98
100
parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
@@ -102,13 +104,19 @@ def main():
102
104
parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
103
105
help = 'for Saving the current Model' )
104
106
args = parser .parse_args ()
105
- use_cuda = not args .no_cuda and torch .cuda .is_available ()
106
107
107
- torch .manual_seed (args .seed )
108
+ if args .cuda and not args .mps :
109
+ device = "cuda"
110
+ elif args .mps and not args .cuda :
111
+ device = "mps"
112
+ else :
113
+ device = "cpu"
114
+
115
+ device = torch .device (device )
108
116
109
- device = torch .device ( "cuda" if use_cuda else "cpu" )
117
+ torch .manual_seed ( args . seed )
110
118
111
- kwargs = {'num_workers' : 1 , 'pin_memory' : True } if use_cuda else {}
119
+ kwargs = {'num_workers' : 1 , 'pin_memory' : True } if args . cuda else {}
112
120
train_loader = torch .utils .data .DataLoader (
113
121
datasets .MNIST ('../data' , train = True , download = True ,
114
122
transform = transforms .Compose ([
0 commit comments