    type=int,
    default=10,
)
-parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")
+parser.add_argument("--mode", type=str, required=True, help="train | test | ptq")

parser.add_argument(
    "--num_monte_carlo",
@@ -333,27 +333,30 @@ def main():
        '''
        model.load_state_dict(checkpoint['state_dict'])

-
        # post-training quantization
        model_int8 = quantize(model, calib_loader, args)
        model_int8.eval()
        model_int8.cpu()

-        for i, (data, target) in enumerate(calib_loader):
-            data = data.cpu()
+        print('Evaluating quantized INT8 model....')
+        evaluate(args, model_int8, val_loader)

-        with torch.no_grad():
-            traced_model = torch.jit.trace(model_int8, data)
-            traced_model = torch.jit.freeze(traced_model)
+        #for i, (data, target) in enumerate(calib_loader):
+        #    data = data.cpu()

-        save_path = os.path.join(
-            args.save_dir,
-            'quantized_bayesian_{}_cifar.pth'.format(args.arch))
-        traced_model.save(save_path)
-        print('INT8 model checkpoint saved at ', save_path)
-        print('Evaluating quantized INT8 model....')
-        evaluate(args, traced_model, val_loader)
+        #with torch.no_grad():
+        #    traced_model = torch.jit.trace(model_int8, data)
+        #    traced_model = torch.jit.freeze(traced_model)
+
+        #save_path = os.path.join(
+        #    args.save_dir,
+        #    'quantized_bayesian_{}_cifar.pth'.format(args.arch))
+        #traced_model.save(save_path)
+        #print('INT8 model checkpoint saved at ', save_path)
+        #print('Evaluating quantized INT8 model....')
+        #evaluate(args, traced_model, val_loader)

+        '''
    elif args.mode =='test_ptq':
        print('load model...')
        if len(args.model_checkpoint) > 0:
@@ -366,7 +369,7 @@ def main():
        model_int8 = torch.jit.freeze(model_int8)
        print('Evaluating the INT8 model....')
        evaluate(args, model_int8, val_loader)
-
+        '''

def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
    batch_time = AverageMeter()
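
Note on the ptq branch above: quantize(model, calib_loader, args) is the example's own helper and its body is not part of this diff. As rough orientation only, the sketch below shows what an eager-mode post-training static quantization step can look like with stock PyTorch; the helper name, the fbgemm backend, and the assumption that the model wraps its float regions in QuantStub/DeQuantStub pairs are illustrative, not taken from the repository.

import torch
from torch.ao.quantization import convert, get_default_qconfig, prepare

def quantize_ptq_sketch(model, calib_loader, backend='fbgemm'):
    # Hypothetical stand-in, not the repo's quantize(): eager-mode
    # post-training static quantization with torch.ao.quantization.
    # Assumes `model` already contains QuantStub/DeQuantStub pairs.
    model.eval().cpu()
    model.qconfig = get_default_qconfig(backend)
    prepared = prepare(model)

    # Feed calibration batches so the observers record activation ranges.
    with torch.no_grad():
        for data, _ in calib_loader:
            prepared(data.cpu())

    # Swap the observed float modules for their INT8 counterparts.
    return convert(prepared)

Whatever the repo's helper does internally, the commit then evaluates the returned INT8 module directly with evaluate(args, model_int8, val_loader) instead of tracing, freezing, and saving a TorchScript checkpoint first; that save/reload path is the block now commented out.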