-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpredict.py
30 lines (27 loc) · 899 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorlayerx as tlx
from tensorlayerx.vision.transforms import Compose, Normalize, Resize, ToTensor
from tensorlayerx.vision.utils import load_image
from tlxzoo.vision.image_classification import ImageClassification
def _main():
    """Classify the bundled demo image with a pretrained VGG16 model and print the result.

    Builds an ``ImageClassification`` model with a ``vgg16`` backbone and
    10 output labels, loads weights from ``model.npz``, preprocesses
    ``dog.png`` and prints whatever ``predict`` returns for a batch of one.

    Both file paths are relative to the current working directory.
    NOTE(review): this presumably means the script must be launched from the
    repository root; confirm.
    """
    weights_path = "./demo/vision/image_classification/vgg/model.npz"
    image_path = "./demo/vision/image_classification/vgg/dog.png"

    # Output index -> class name:
    #   0 airplane   1 automobile   2 bird    3 cat    4 deer
    #   5 dog        6 frog         7 horse   8 ship   9 truck
    # NOTE(review): these ten names match CIFAR-10's class order; confirm
    # against the script that produced model.npz.
    classifier = ImageClassification(
        backbone="vgg16",
        l2_weights=True,
        num_labels=10,
    )
    classifier.load_weights(weights_path)
    # Leave training mode before running inference.
    classifier.set_eval()

    raw_image = load_image(image_path)

    # Resize to 32x32, normalise with fixed per-channel mean/std, then convert
    # to a tensor.
    # NOTE(review): the mean/std magnitudes suggest raw 0-255 pixel values are
    # expected at the Normalize step; confirm with tlx's load_image/Resize.
    preprocess = Compose([
        Resize((32, 32)),
        Normalize(mean=(125.31, 122.95, 113.86), std=(62.99, 62.09, 66.70)),
        ToTensor(),
    ])
    # Prepend a batch axis of size 1 so the model receives a one-image batch.
    batch = tlx.expand_dims(preprocess(raw_image), 0)

    print(classifier.predict(batch))


if __name__ == '__main__':
    _main()