@@ -44,43 +44,42 @@ def rmdir(_dir):
 
 
 class TestGemLearner(unittest.TestCase):
+    temp_dir = os.path.join("tests",
+                            "sources",
+                            "tools",
+                            "perception",
+                            "object_detection_2d",
+                            "gem",
+                            "gem_temp",
+                            )
+    dataset_location = os.path.join(temp_dir, 'sample_dataset')
+    learners = {}
+    model_backbones = ["resnet50", "mobilenetv2"]
+
     @classmethod
     def setUpClass(cls):
         print("\n\n*********************************\nTEST Object Detection GEM Learner\n"
               "*********************************")
-        cls.temp_dir = os.path.join("tests", "sources", "tools",
-                                    "perception", "object_detection_2d",
-                                    "gem", "gem_temp")
-
-        cls.model_backbone = "resnet50"
+        for backbone in cls.model_backbones:
+            cls.learners[backbone] = GemLearner(iters=1,
+                                                temp_path=cls.temp_dir,
+                                                backbone=backbone,
+                                                num_classes=7,
+                                                device=DEVICE,
+                                                )
 
-        cls.learner = GemLearner(iters=1,
-                                 temp_path=cls.temp_dir,
-                                 backbone=cls.model_backbone,
-                                 num_classes=7,
-                                 device=DEVICE,
-                                 )
-
-        cls.learner.download(mode='pretrained_gem')
+        for learner in cls.learners.values():
+            learner.download(mode='pretrained_gem')
 
         print("Model downloaded", file=sys.stderr)
 
-        cls.learner.download(mode='test_data_sample_dataset')
+        cls.learners['resnet50'].download(mode='test_data_sample_dataset')
 
-        cls.learner.download(mode='test_data_sample_images')
+        cls.learners['resnet50'].download(mode='test_data_sample_images')
 
         print("Data downloaded", file=sys.stderr)
-        cls.dataset_location = os.path.join(cls.temp_dir,
-                                            'sample_dataset',
-                                            )
-        cls.m1_dataset = ExternalDataset(
-            cls.dataset_location,
-            "coco",
-        )
-        cls.m2_dataset = ExternalDataset(
-            cls.dataset_location,
-            "coco",
-        )
+        cls.m1_dataset = ExternalDataset(cls.dataset_location, "coco")
+        cls.m2_dataset = ExternalDataset(cls.dataset_location, "coco")
 
     @classmethod
     def tearDownClass(cls):
@@ -99,35 +98,36 @@ def test_fit(self):
         # version)
         warnings.simplefilter("ignore", ResourceWarning)
         warnings.simplefilter("ignore", DeprecationWarning)
-        self.learner.model = None
-        self.learner.ort_session = None
-
-        self.learner.download(mode='pretrained_gem')
-
-        m = list(self.learner.model.parameters())[0].clone()
-
-        self.learner.fit(
-            m1_train_edataset=self.m1_dataset,
-            m2_train_edataset=self.m2_dataset,
-            annotations_folder='annotations',
-            m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
-            m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
-            m1_train_images_folder='train/m1',
-            m2_train_images_folder='train/m2',
-            out_dir=os.path.join(self.temp_dir, "outputs"),
-            trial_dir=os.path.join(self.temp_dir, "trial"),
-            logging_path='',
-            verbose=False,
-            m1_val_edataset=self.m1_dataset,
-            m2_val_edataset=self.m2_dataset,
-            m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
-            m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
-            m1_val_images_folder='val/m1',
-            m2_val_images_folder='val/m2',
-        )
 
-        self.assertFalse(torch.equal(m, list(self.learner.model.parameters())[0]),
-                         msg="Model parameters did not change after running fit.")
+        for backbone in self.model_backbones:
+            self.learners[backbone].model = None
+            self.learners[backbone].ort_session = None
+
+            self.learners[backbone].download(mode='pretrained_gem')
+
+            m = list(self.learners[backbone].model.parameters())[0].clone()
+
+            self.learners[backbone].fit(m1_train_edataset=self.m1_dataset,
+                                        m2_train_edataset=self.m2_dataset,
+                                        annotations_folder='annotations',
+                                        m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
+                                        m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
+                                        m1_train_images_folder='train/m1',
+                                        m2_train_images_folder='train/m2',
+                                        out_dir=os.path.join(self.temp_dir, "outputs"),
+                                        trial_dir=os.path.join(self.temp_dir, "trial"),
+                                        logging_path='',
+                                        verbose=False,
+                                        m1_val_edataset=self.m1_dataset,
+                                        m2_val_edataset=self.m2_dataset,
+                                        m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
+                                        m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
+                                        m1_val_images_folder='val/m1',
+                                        m2_val_images_folder='val/m2',
+                                        )
+
+            self.assertFalse(torch.equal(m, list(self.learners[backbone].model.parameters())[0]),
+                             msg="Model parameters did not change after running fit.")
 
         # Cleanup
         warnings.simplefilter("default", ResourceWarning)
@@ -139,58 +139,60 @@ def test_eval(self):
         # version)
         warnings.simplefilter("ignore", ResourceWarning)
         warnings.simplefilter("ignore", DeprecationWarning)
-        self.learner.model = None
-        self.learner.ort_session = None
-
-        self.learner.download(mode='pretrained_gem')
-
-        result = self.learner.eval(
-            m1_edataset=self.m1_dataset,
-            m2_edataset=self.m2_dataset,
-            m1_images_folder='val/m1',
-            m2_images_folder='val/m2',
-            annotations_folder='annotations',
-            m1_annotations_file='RGB_26May2021_14h19m_coco.json',
-            m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
-            verbose=False,
-        )
 
-        self.assertGreater(len(result), 0)
+        for backbone in self.model_backbones:
+            self.learners[backbone].model = None
+            self.learners[backbone].ort_session = None
+
+            self.learners[backbone].download(mode='pretrained_gem')
+
+            result = self.learners[backbone].eval(
+                m1_edataset=self.m1_dataset,
+                m2_edataset=self.m2_dataset,
+                m1_images_folder='val/m1',
+                m2_images_folder='val/m2',
+                annotations_folder='annotations',
+                m1_annotations_file='RGB_26May2021_14h19m_coco.json',
+                m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
+                verbose=False,
+            )
+
+            self.assertGreater(len(result), 0)
 
         # Cleanup
         warnings.simplefilter("default", ResourceWarning)
         warnings.simplefilter("default", DeprecationWarning)
 
     def test_infer(self):
-        self.learner.model = None
-        self.learner.ort_session = None
-
-        self.learner.download(mode='pretrained_gem')
-
         m1_image = Image.open(os.path.join(self.temp_dir, "sample_images/rgb/2021_04_22_21_35_47_852516.jpg"))
         m2_image = Image.open(os.path.join(self.temp_dir, 'sample_images/aligned_infra/2021_04_22_21_35_47_852516.jpg'))
 
-        result, _, _ = self.learner.infer(m1_image, m2_image)
-
-        self.assertGreater(len(result), 0)
+        for backbone in self.model_backbones:
+            self.learners[backbone].model = None
+            self.learners[backbone].ort_session = None
+            self.learners[backbone].download(mode='pretrained_gem')
+            result, _, _ = self.learners[backbone].infer(m1_image, m2_image)
+            self.assertGreater(len(result), 0)
 
     def test_save(self):
-        self.learner.model = None
-        self.learner.ort_session = None
+        backbone = 'resnet50'
+        self.learners[backbone].model = None
+        self.learners[backbone].ort_session = None
 
         model_dir = os.path.join(self.temp_dir, "test_model")
 
-        self.learner.download(mode='pretrained_detr')
+        self.learners[backbone].download(mode='pretrained_detr')
 
-        self.learner.save(model_dir)
+        self.learners[backbone].save(model_dir)
 
-        starting_param_1 = list(self.learner.model.parameters())[0].clone()
+        starting_param_1 = list(self.learners[backbone].model.parameters())[0].clone()
 
         learner2 = GemLearner(
             iters=1,
             temp_path=self.temp_dir,
             device=DEVICE,
             num_classes=7,
+            backbone=backbone,
         )
         learner2.load(model_dir)
 