|
20 | 20 | # parser.add_argument('--d-vvs', default=2, type=int, help='ventral depth')
|
21 | 21 | # parser.add_argument('--cache', default=250, type=int, help='cache size')
|
22 | 22 | parser.add_argument('--root', type=str, help='root')
|
| 23 | +parser.add_argument('--baseline', type=bool, help='train baseline?') |
23 | 24 | args = parser.parse_args()
|
24 | 25 |
|
25 | 26 | bottlenecks = [1, 2, 4, 8, 16, 32]
|
26 | 27 |
|
27 | 28 | n_trials = 5
|
28 | 29 |
|
29 |
| -param_grid = ParameterGrid({ |
30 |
| - 'n_bn': bottlenecks, |
31 |
| - 'a': list(range(n_trials)) |
32 |
| -}) |
| 30 | +if args.baseline: |
| 31 | + retina_ch = 64 |
| 32 | + n_bn = 64 |
| 33 | + rep = args.arr |
33 | 34 |
|
34 |
| -params = param_grid[args.arr] |
35 |
| -n_bn = params['n_bn'] |
36 |
| -rep = params['a'] |
| 35 | + model_file = f'resnet50_baseline_{rep}' |
| 36 | +else: |
| 37 | + retina_ch = 32 |
| 38 | + param_grid = ParameterGrid({ |
| 39 | + 'n_bn': bottlenecks, |
| 40 | + 'a': list(range(n_trials)) |
| 41 | + }) |
| 42 | + |
| 43 | + params = param_grid[args.arr] |
| 44 | + n_bn = params['n_bn'] |
| 45 | + rep = params['a'] |
| 46 | + |
| 47 | + model_file = f'resnet50_{n_bn}_{rep}' |
37 | 48 |
|
38 | 49 | # n_bn = bns[args.arr % 6]
|
39 | 50 | # rep = args.arr // 6
|
40 | 51 |
|
41 | 52 | dir = '/scratch/ewah1g13/models/'
|
42 |
| -model_file = f'resnet50_{n_bn}_{rep}' |
| 53 | + |
43 | 54 | # log_file = f'./logs/imagenet/resnet50_{n_bn}_{rep}.csv'
|
44 | 55 |
|
45 | 56 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
73 | 84 | testset = ImageNetHDF5(f'{args.root}/val', transform=test_transform, cache_size=1000)
|
74 | 85 |
|
75 | 86 | # create data loaders
|
76 |
| -trainloader = DataLoader(trainset, batch_size=1024, shuffle=True, pin_memory=True, num_workers=16) |
77 |
| -testloader = DataLoader(testset, batch_size=1024, shuffle=False, pin_memory=True, num_workers=16) |
| 87 | +trainloader = DataLoader(trainset, batch_size=1024, shuffle=True, pin_memory=True, num_workers=64) |
| 88 | +testloader = DataLoader(testset, batch_size=1024, shuffle=False, pin_memory=True, num_workers=64) |
78 | 89 |
|
79 | 90 | # model = ImageNetModel(n_bn, args.d_vvs, n_inch=3)
|
80 |
| -model = RetinalBottleneckModel(n_bn, 'resnet50', n_out=1000, n_inch=3, retina_kernel_size=7) |
| 91 | +model = RetinalBottleneckModel(n_bn, 'resnet50', n_out=1000, n_inch=3, retina_kernel_size=7, retina_ch=retina_ch) |
81 | 92 | model = nn.DataParallel(model)
|
82 | 93 | # print(model)
|
83 | 94 |
|
|
0 commit comments