Skip to content

Commit 3e74914

Browse files
committed
add ImageNet example
1 parent 93cf0d3 commit 3e74914

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

bayesian_torch/ao/quantization/quantize.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from typing import Any, List, Optional, Type, Union
3939
from torch import Tensor
4040
from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
41+
from torch.nn import BatchNorm2d
4142
# import copy
4243

4344
__all__ = [
@@ -140,6 +141,14 @@ def enable_prepare(m):
140141
if callable(prepare):
141142
m._modules[name].prepare()
142143
m._modules[name].dnn_to_bnn_flag=True
144+
elif "BatchNorm2dLayer" in m._modules[name].__class__.__name__: # replace BatchNorm2dLayer with BatchNorm2d in downsample
145+
layer_fn = BatchNorm2d # Get QBNN layer
146+
bn_layer = layer_fn(
147+
num_features=m._modules[name].num_features
148+
)
149+
bn_layer.__dict__.update(m._modules[name].__dict__)
150+
setattr(m, name, bn_layer)
151+
143152

144153

145154
def prepare(model):
@@ -149,7 +158,7 @@ def prepare(model):
149158
3. run torch.quantize.prepare()
150159
"""
151160
qmodel = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3])
152-
qmodel.load_state_dict(model.state_dict())
161+
qmodel.load_state_dict(model.module.state_dict())
153162
qmodel.eval()
154163
enable_prepare(qmodel)
155164
qmodel.qconfig = torch.quantization.get_default_qconfig("onednn")

bayesian_torch/layers/batchnorm.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self,
1515
affine=True,
1616
track_running_stats=True):
1717
super(BatchNorm2dLayer, self).__init__()
18+
self.num_features = num_features
1819
self.eps = eps
1920
self.momentum = momentum
2021
self.affine = affine
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
2+
3+
model=resnet50
4+
mode='test'
5+
val_batch_size=1
6+
num_monte_carlo=1
7+
8+
python examples/main_bayesian_imagenet_bnn2qbnn.py --mode=$mode --val_batch_size=$val_batch_size --num_monte_carlo=$num_monte_carlo ../../datasets

0 commit comments

Comments
 (0)