Skip to content

Commit f5c7126

Browse files
Fix bug in post-training quantization evaluation caused by JIT trace.
Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent 93cf0d3 commit f5c7126

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
type=int,
109109
default=10,
110110
)
111-
parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")
111+
parser.add_argument("--mode", type=str, required=True, help="train | test | ptq")
112112

113113
parser.add_argument(
114114
"--num_monte_carlo",
@@ -333,27 +333,30 @@ def main():
333333
'''
334334
model.load_state_dict(checkpoint['state_dict'])
335335

336-
337336
# post-training quantization
338337
model_int8 = quantize(model, calib_loader, args)
339338
model_int8.eval()
340339
model_int8.cpu()
341340

342-
for i, (data, target) in enumerate(calib_loader):
343-
data = data.cpu()
341+
print('Evaluating quantized INT8 model....')
342+
evaluate(args, model_int8, val_loader)
344343

345-
with torch.no_grad():
346-
traced_model = torch.jit.trace(model_int8, data)
347-
traced_model = torch.jit.freeze(traced_model)
344+
#for i, (data, target) in enumerate(calib_loader):
345+
# data = data.cpu()
348346

349-
save_path = os.path.join(
350-
args.save_dir,
351-
'quantized_bayesian_{}_cifar.pth'.format(args.arch))
352-
traced_model.save(save_path)
353-
print('INT8 model checkpoint saved at ', save_path)
354-
print('Evaluating quantized INT8 model....')
355-
evaluate(args, traced_model, val_loader)
347+
#with torch.no_grad():
348+
# traced_model = torch.jit.trace(model_int8, data)
349+
# traced_model = torch.jit.freeze(traced_model)
350+
351+
#save_path = os.path.join(
352+
# args.save_dir,
353+
# 'quantized_bayesian_{}_cifar.pth'.format(args.arch))
354+
#traced_model.save(save_path)
355+
#print('INT8 model checkpoint saved at ', save_path)
356+
#print('Evaluating quantized INT8 model....')
357+
#evaluate(args, traced_model, val_loader)
356358

359+
'''
357360
elif args.mode =='test_ptq':
358361
print('load model...')
359362
if len(args.model_checkpoint) > 0:
@@ -366,7 +369,7 @@ def main():
366369
model_int8 = torch.jit.freeze(model_int8)
367370
print('Evaluating the INT8 model....')
368371
evaluate(args, model_int8, val_loader)
369-
372+
'''
370373

371374
def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
372375
batch_time = AverageMeter()

0 commit comments

Comments
 (0)