Skip to content

Commit 1af87cf

Browse files
committed
Add baseline option to training
1 parent ede833c commit 1af87cf

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

training/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ class RetinalBottleneckModel(nn.Module):
4343
The first layer must be an nn.Conv2d for the rewriting to work. Similar limitations apply to the final layer,
4444
which needs to be last in the _modules dict (or last in the last _modules dict entry if there is an nn.Sequential wrapper).
4545
"""
46-
def __init__(self, n_bn, ventral, n_inch=1, n_out=10, init=True, retina_kernel_size=9, transform=None):
46+
def __init__(self, n_bn, ventral, n_inch=1, n_out=10, init=True, retina_kernel_size=9, retina_ch=32, transform=None):
4747
super(RetinalBottleneckModel, self).__init__()
4848

4949
self.transform = transform
5050

5151
self.retina = nn.Sequential()
52-
self.retina.add_module("retina_conv1", nn.Conv2d(n_inch, 32, (retina_kernel_size, retina_kernel_size), padding=retina_kernel_size // 2))
52+
self.retina.add_module("retina_conv1", nn.Conv2d(n_inch, retina_ch, (retina_kernel_size, retina_kernel_size), padding=retina_kernel_size // 2))
5353
self.retina.add_module("retina_relu1", nn.ReLU())
54-
self.retina.add_module("retina_conv2", nn.Conv2d(32, n_bn, (retina_kernel_size, retina_kernel_size), padding=retina_kernel_size // 2))
54+
self.retina.add_module("retina_conv2", nn.Conv2d(retina_ch, n_bn, (retina_kernel_size, retina_kernel_size), padding=retina_kernel_size // 2))
5555
self.retina.add_module("retina_relu2", nn.ReLU())
5656

5757
if isinstance(ventral, int):

training/train_imagenet.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,37 @@
2020
# parser.add_argument('--d-vvs', default=2, type=int, help='ventral depth')
2121
# parser.add_argument('--cache', default=250, type=int, help='cache size')
2222
parser.add_argument('--root', type=str, help='root')
23+
# NOTE: `type=bool` would call bool() on the raw string, so any non-empty value
# (including "False" or "0") would enable the baseline. Parse the text explicitly
# instead; a bare `--baseline` (no value) also enables it, and omitting the flag
# leaves it disabled.
parser.add_argument('--baseline', nargs='?', const=True, default=False,
                    type=lambda s: s.strip().lower() in ('1', 'true', 't', 'yes', 'y'),
                    help='train baseline?')
2324
args = parser.parse_args()
2425

2526
bottlenecks = [1, 2, 4, 8, 16, 32]
2627

2728
n_trials = 5
2829

29-
param_grid = ParameterGrid({
30-
'n_bn': bottlenecks,
31-
'a': list(range(n_trials))
32-
})
30+
if args.baseline:
31+
retina_ch = 64
32+
n_bn = 64
33+
rep = args.arr
3334

34-
params = param_grid[args.arr]
35-
n_bn = params['n_bn']
36-
rep = params['a']
35+
model_file = f'resnet50_baseline_{rep}'
36+
else:
37+
retina_ch = 32
38+
param_grid = ParameterGrid({
39+
'n_bn': bottlenecks,
40+
'a': list(range(n_trials))
41+
})
42+
43+
params = param_grid[args.arr]
44+
n_bn = params['n_bn']
45+
rep = params['a']
46+
47+
model_file = f'resnet50_{n_bn}_{rep}'
3748

3849
# n_bn = bns[args.arr % 6]
3950
# rep = args.arr // 6
4051

4152
dir = '/scratch/ewah1g13/models/'
42-
model_file = f'resnet50_{n_bn}_{rep}'
53+
4354
# log_file = f'./logs/imagenet/resnet50_{n_bn}_{rep}.csv'
4455

4556
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -73,11 +84,11 @@
7384
testset = ImageNetHDF5(f'{args.root}/val', transform=test_transform, cache_size=1000)
7485

7586
# create data loaders
76-
trainloader = DataLoader(trainset, batch_size=1024, shuffle=True, pin_memory=True, num_workers=16)
77-
testloader = DataLoader(testset, batch_size=1024, shuffle=False, pin_memory=True, num_workers=16)
87+
trainloader = DataLoader(trainset, batch_size=1024, shuffle=True, pin_memory=True, num_workers=64)
88+
testloader = DataLoader(testset, batch_size=1024, shuffle=False, pin_memory=True, num_workers=64)
7889

7990
# model = ImageNetModel(n_bn, args.d_vvs, n_inch=3)
80-
model = RetinalBottleneckModel(n_bn, 'resnet50', n_out=1000, n_inch=3, retina_kernel_size=7)
91+
model = RetinalBottleneckModel(n_bn, 'resnet50', n_out=1000, n_inch=3, retina_kernel_size=7, retina_ch=retina_ch)
8192
model = nn.DataParallel(model)
8293
# print(model)
8394

0 commit comments

Comments
 (0)