Skip to content

Commit fe91adf

Browse files
authored
Gem Mobilenet-v2 pretrained backbone (#226)
* Add pretrained mobilenetv2 backbone * Update gem.md Co-authored-by: Jelle Luijkx <[email protected]>
1 parent 822dbc7 commit fe91adf

File tree

2 files changed

+86
-84
lines changed
  • docs/reference
  • tests/sources/tools/perception/object_detection_2d/gem

2 files changed

+86
-84
lines changed

Diff for: docs/reference/gem.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ Parameters:
208208
Valid values are: "weights_detr", "pretrained_detr", "pretrained_gem", "test_data_l515" and "test_data_sample_images".
209209
In case of "weights_detr", the weigths for single modal DETR with *resnet50* backbone are downloaded.
210210
In case of "pretrained_detr", the weigths for single modal pretrained DETR with *resnet50* backbone are downloaded.
211-
In case of "pretrained_gem", the weights from *'gem_scavg_e294_mAP0983_rn50_l515_7cls.pth'* (backbone: *'resnet50'*, fusion_method: *'scalar averaged'*, trained on *RGB-Infrared l515_dataset* are downloaded.
211+
In case of "pretrained_gem", the weights (backbone: *'resnet50' or 'mobilenetv2'*, fusion_method: *'scalar averaged'*, trained on *RGB-Infrared l515_dataset*) are downloaded.
212212
In case of "test_data_l515", the *RGB-Infrared l515* dataset is downloaded from the OpenDR server.
213213
In case of "test_data_sample images", two sample images for testing the *infer* function are downloaded.
214214
- **verbose** : *bool, default=False*

Diff for: tests/sources/tools/perception/object_detection_2d/gem/test_gem.py

+85-83
Original file line numberDiff line numberDiff line change
@@ -44,43 +44,42 @@ def rmdir(_dir):
4444

4545

4646
class TestGemLearner(unittest.TestCase):
47+
temp_dir = os.path.join("tests",
48+
"sources",
49+
"tools",
50+
"perception",
51+
"object_detection_2d",
52+
"gem",
53+
"gem_temp",
54+
)
55+
dataset_location = os.path.join(temp_dir, 'sample_dataset')
56+
learners = {}
57+
model_backbones = ["resnet50", "mobilenetv2"]
58+
4759
@classmethod
4860
def setUpClass(cls):
4961
print("\n\n*********************************\nTEST Object Detection GEM Learner\n"
5062
"*********************************")
51-
cls.temp_dir = os.path.join("tests", "sources", "tools",
52-
"perception", "object_detection_2d",
53-
"gem", "gem_temp")
54-
55-
cls.model_backbone = "resnet50"
63+
for backbone in cls.model_backbones:
64+
cls.learners[backbone] = GemLearner(iters=1,
65+
temp_path=cls.temp_dir,
66+
backbone=backbone,
67+
num_classes=7,
68+
device=DEVICE,
69+
)
5670

57-
cls.learner = GemLearner(iters=1,
58-
temp_path=cls.temp_dir,
59-
backbone=cls.model_backbone,
60-
num_classes=7,
61-
device=DEVICE,
62-
)
63-
64-
cls.learner.download(mode='pretrained_gem')
71+
for learner in cls.learners.values():
72+
learner.download(mode='pretrained_gem')
6573

6674
print("Model downloaded", file=sys.stderr)
6775

68-
cls.learner.download(mode='test_data_sample_dataset')
76+
cls.learners['resnet50'].download(mode='test_data_sample_dataset')
6977

70-
cls.learner.download(mode='test_data_sample_images')
78+
cls.learners['resnet50'].download(mode='test_data_sample_images')
7179

7280
print("Data downloaded", file=sys.stderr)
73-
cls.dataset_location = os.path.join(cls.temp_dir,
74-
'sample_dataset',
75-
)
76-
cls.m1_dataset = ExternalDataset(
77-
cls.dataset_location,
78-
"coco",
79-
)
80-
cls.m2_dataset = ExternalDataset(
81-
cls.dataset_location,
82-
"coco",
83-
)
81+
cls.m1_dataset = ExternalDataset(cls.dataset_location, "coco")
82+
cls.m2_dataset = ExternalDataset(cls.dataset_location, "coco")
8483

8584
@classmethod
8685
def tearDownClass(cls):
@@ -99,35 +98,36 @@ def test_fit(self):
9998
# version)
10099
warnings.simplefilter("ignore", ResourceWarning)
101100
warnings.simplefilter("ignore", DeprecationWarning)
102-
self.learner.model = None
103-
self.learner.ort_session = None
104-
105-
self.learner.download(mode='pretrained_gem')
106-
107-
m = list(self.learner.model.parameters())[0].clone()
108-
109-
self.learner.fit(
110-
m1_train_edataset=self.m1_dataset,
111-
m2_train_edataset=self.m2_dataset,
112-
annotations_folder='annotations',
113-
m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
114-
m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
115-
m1_train_images_folder='train/m1',
116-
m2_train_images_folder='train/m2',
117-
out_dir=os.path.join(self.temp_dir, "outputs"),
118-
trial_dir=os.path.join(self.temp_dir, "trial"),
119-
logging_path='',
120-
verbose=False,
121-
m1_val_edataset=self.m1_dataset,
122-
m2_val_edataset=self.m2_dataset,
123-
m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
124-
m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
125-
m1_val_images_folder='val/m1',
126-
m2_val_images_folder='val/m2',
127-
)
128101

129-
self.assertFalse(torch.equal(m, list(self.learner.model.parameters())[0]),
130-
msg="Model parameters did not change after running fit.")
102+
for backbone in self.model_backbones:
103+
self.learners[backbone].model = None
104+
self.learners[backbone].ort_session = None
105+
106+
self.learners[backbone].download(mode='pretrained_gem')
107+
108+
m = list(self.learners[backbone].model.parameters())[0].clone()
109+
110+
self.learners[backbone].fit(m1_train_edataset=self.m1_dataset,
111+
m2_train_edataset=self.m2_dataset,
112+
annotations_folder='annotations',
113+
m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
114+
m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
115+
m1_train_images_folder='train/m1',
116+
m2_train_images_folder='train/m2',
117+
out_dir=os.path.join(self.temp_dir, "outputs"),
118+
trial_dir=os.path.join(self.temp_dir, "trial"),
119+
logging_path='',
120+
verbose=False,
121+
m1_val_edataset=self.m1_dataset,
122+
m2_val_edataset=self.m2_dataset,
123+
m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
124+
m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
125+
m1_val_images_folder='val/m1',
126+
m2_val_images_folder='val/m2',
127+
)
128+
129+
self.assertFalse(torch.equal(m, list(self.learners[backbone].model.parameters())[0]),
130+
msg="Model parameters did not change after running fit.")
131131

132132
# Cleanup
133133
warnings.simplefilter("default", ResourceWarning)
@@ -139,58 +139,60 @@ def test_eval(self):
139139
# version)
140140
warnings.simplefilter("ignore", ResourceWarning)
141141
warnings.simplefilter("ignore", DeprecationWarning)
142-
self.learner.model = None
143-
self.learner.ort_session = None
144-
145-
self.learner.download(mode='pretrained_gem')
146-
147-
result = self.learner.eval(
148-
m1_edataset=self.m1_dataset,
149-
m2_edataset=self.m2_dataset,
150-
m1_images_folder='val/m1',
151-
m2_images_folder='val/m2',
152-
annotations_folder='annotations',
153-
m1_annotations_file='RGB_26May2021_14h19m_coco.json',
154-
m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
155-
verbose=False,
156-
)
157142

158-
self.assertGreater(len(result), 0)
143+
for backbone in self.model_backbones:
144+
self.learners[backbone].model = None
145+
self.learners[backbone].ort_session = None
146+
147+
self.learners[backbone].download(mode='pretrained_gem')
148+
149+
result = self.learners[backbone].eval(
150+
m1_edataset=self.m1_dataset,
151+
m2_edataset=self.m2_dataset,
152+
m1_images_folder='val/m1',
153+
m2_images_folder='val/m2',
154+
annotations_folder='annotations',
155+
m1_annotations_file='RGB_26May2021_14h19m_coco.json',
156+
m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
157+
verbose=False,
158+
)
159+
160+
self.assertGreater(len(result), 0)
159161

160162
# Cleanup
161163
warnings.simplefilter("default", ResourceWarning)
162164
warnings.simplefilter("default", DeprecationWarning)
163165

164166
def test_infer(self):
165-
self.learner.model = None
166-
self.learner.ort_session = None
167-
168-
self.learner.download(mode='pretrained_gem')
169-
170167
m1_image = Image.open(os.path.join(self.temp_dir, "sample_images/rgb/2021_04_22_21_35_47_852516.jpg"))
171168
m2_image = Image.open(os.path.join(self.temp_dir, 'sample_images/aligned_infra/2021_04_22_21_35_47_852516.jpg'))
172169

173-
result, _, _ = self.learner.infer(m1_image, m2_image)
174-
175-
self.assertGreater(len(result), 0)
170+
for backbone in self.model_backbones:
171+
self.learners[backbone].model = None
172+
self.learners[backbone].ort_session = None
173+
self.learners[backbone].download(mode='pretrained_gem')
174+
result, _, _ = self.learners[backbone].infer(m1_image, m2_image)
175+
self.assertGreater(len(result), 0)
176176

177177
def test_save(self):
178-
self.learner.model = None
179-
self.learner.ort_session = None
178+
backbone = 'resnet50'
179+
self.learners[backbone].model = None
180+
self.learners[backbone].ort_session = None
180181

181182
model_dir = os.path.join(self.temp_dir, "test_model")
182183

183-
self.learner.download(mode='pretrained_detr')
184+
self.learners[backbone].download(mode='pretrained_detr')
184185

185-
self.learner.save(model_dir)
186+
self.learners[backbone].save(model_dir)
186187

187-
starting_param_1 = list(self.learner.model.parameters())[0].clone()
188+
starting_param_1 = list(self.learners[backbone].model.parameters())[0].clone()
188189

189190
learner2 = GemLearner(
190191
iters=1,
191192
temp_path=self.temp_dir,
192193
device=DEVICE,
193194
num_classes=7,
195+
backbone=backbone,
194196
)
195197
learner2.load(model_dir)
196198

0 commit comments

Comments
 (0)