-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
45 lines (32 loc) · 1.27 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
from time import time
import json
import torch
import helper
def get_input_args(argv=None):
    """Parse the command-line arguments for the prediction script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with attributes:
            input (str): path to the image to classify.
            checkpoint (str): path to the checkpoint to load.
            top_k (int): how many top classes to report (default 5).
            gpu (bool): whether GPU inference was requested (default False).
            cat_names (str or None): optional JSON file mapping class labels
                to display names.
    """
    parser = argparse.ArgumentParser(
        description='Predict the class of an image using a trained checkpoint.')
    parser.add_argument('input', type=str, help='path to the input image')
    parser.add_argument('checkpoint', type=str,
                        help='path to the model checkpoint used for prediction')
    parser.add_argument('--top_k', type=int, default=5,
                        help='number of most likely classes to report')
    parser.add_argument('--gpu', dest='gpu', action='store_true',
                        help='run inference on the GPU if one is available')
    parser.add_argument('--cat_names', type=str,
                        help='JSON file mapping class labels to names')
    # store_true already defaults to False; kept explicit for readability.
    parser.set_defaults(gpu=False)
    return parser.parse_args(argv)
def main():
    """Load a checkpoint, predict the top-k classes of an image and print them.

    Configuration comes from the command line via ``get_input_args``. When
    ``--cat_names`` is given, that JSON file is loaded and used to replace
    each predicted class label with its mapped name in the printed output.
    """
    input_args = get_input_args()
    # Use the GPU only when it was requested AND CUDA is actually available.
    gpu = torch.cuda.is_available() and input_args.gpu
    print("Predicting on {} using {}".format(
        "GPU" if gpu else "CPU", input_args.checkpoint))

    model = helper.load_checkpoint(input_args.checkpoint)
    if gpu:
        model.cuda()

    # Optional label -> display-name mapping. Previously the file was loaded
    # but never used, so --cat_names had no visible effect.
    cat_to_name = None
    if input_args.cat_names:
        with open(input_args.cat_names, 'r', encoding='utf-8') as f:
            cat_to_name = json.load(f)

    probs, classes = helper.predict(input_args.input, model, gpu,
                                    input_args.top_k)

    top_k = input_args.top_k
    # zip over slices instead of indexing range(top_k): if fewer than top_k
    # results come back, this prints what exists instead of raising IndexError.
    for class_id, prob in zip(classes[:top_k], probs[:top_k]):
        label = class_id
        if cat_to_name is not None:
            # JSON object keys are always strings, so look the label up via
            # str(); fall back to the raw label when it is not in the mapping.
            # NOTE(review): assumes the JSON file holds an object (dict) —
            # confirm against the mapping files in use.
            label = cat_to_name.get(str(class_id), class_id)
        print("probability of class {}: {}".format(label, prob))


if __name__ == "__main__":
    main()