Skip to content

Commit b0e93c6

Browse files
committed
being able to load weights from a checkpoint and to run test successfully.
1 parent 2a19a3b commit b0e93c6

File tree

4 files changed

+61
-10
lines changed

4 files changed

+61
-10
lines changed

examples/mnist.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import pickle
1212

1313

14-
def train(epochs=10, dryrun=False, debug=False, batch_size=64,
14+
def train(epochs=10, batch_size=64, dryrun=False, debug=False,
1515
save_file='data/mnist_model_ckpt.pkl'):
1616

17-
dataset = MNIST('./data/MNIST/mnn_test.pickle')
17+
dataset = MNIST('./data/MNIST/mnn_train.pickle')
1818
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
1919
shuffle=True, collate_fn=lambda batch: batch)
2020

@@ -47,12 +47,41 @@ def train(epochs=10, dryrun=False, debug=False, batch_size=64,
4747

4848
print('saving checkpoint ...')
4949
with open(save_file, 'wb') as fh:
50-
save = net.state_dict(), net.config()
50+
save = net.state_dict(), net.get_config()
5151
pickle.dump(save, fh)
5252

5353

54-
def test(dryrun=False, debug=False):
55-
pass
54+
def test(checkpoint, batch_size=64):
    """Evaluate a saved MNIST classifier and print its test accuracy.

    Args:
        checkpoint: Path to a pickle file containing a
            ``(state_dict, config)`` pair, as unpacked below.
        batch_size: Number of samples per evaluation batch.

    Prints ``test accuracy: <value>`` to stdout; returns nothing.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # checkpoints that come from a trusted source.
    with open(checkpoint, 'rb') as fh:
        state_dict, config = pickle.load(fh)

    dataset = MNIST('./data/MNIST/mnn_test.pickle')
    # Accuracy is independent of sample order, so evaluation runs unshuffled
    # to keep it deterministic. The identity collate_fn hands each batch over
    # as a plain list of (data, label) pairs.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False,
                                         collate_fn=lambda batch: batch)

    # The network is rebuilt by hand; the checkpoint only supplies weights
    # and a config, so this layout must match the one that was saved.
    net = SequentialLayers([
        LinearLayer(28 * 28, 256),
        ReluLayer(),
        LinearLayer(256, 10),
        SoftmaxLayer()
    ])
    net.load_weights(state_dict, config=config, verbose=True)

    correct_cnt, inference_cnt = 0, 0
    for batch in loader:
        images = Tensor([data for data, label in batch])
        # presumably the layers expect a trailing unit axis - TODO confirm
        images = images.unsqueeze(-1)
        labels = Tensor([label for data, label in batch])

        scores = net(images).squeeze(-1)

        preds = scores.argmax(-1)
        corrects = (preds == labels)
        correct_cnt += corrects.sum().item()
        inference_cnt += labels.shape[0]

    # An empty dataset would otherwise fail with ZeroDivisionError below.
    if inference_cnt == 0:
        print('test accuracy: n/a (no samples evaluated)')
        return

    accuracy = correct_cnt / inference_cnt
    print(f'test accuracy: {accuracy:.3f}')
5685

5786

5887
if __name__ == '__main__':

mnn/layer.py

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ def _state_dict(self):
1717
state_dict[key] = (shape, param)
1818
return state_dict
1919

20+
def _load_weights(self, state_dict, config=None):
    """Overwrite this layer's parameters with values from ``state_dict``.

    Every key must have the form ``"<layer name>.<param key>"`` and every
    value must be a ``(shape, param)`` pair. ``config`` is accepted but not
    used here.

    Raises:
        AssertionError: if the layer name differs from ``self.name``, the
            parameter key is not in ``self.params``, or the stored shape
            differs from the current parameter's shape.
    """
    for qualified_key, entry in state_dict.items():
        expected_shape, values = entry
        layer_name, param_key = qualified_key.split('.')
        # Check that the entry really belongs to this layer before
        # replacing anything.
        assert layer_name == self.name
        assert param_key in self.params
        assert expected_shape == self.params[param_key].shape
        self.params[param_key] = Tensor(values)
27+
2028
def _accumulate_grads(self, key, val):
2129
reduced_val = self._batch_reduced(val)
2230
if key in self.grads:

mnn/seq_layers.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from mnn.tensor import Tensor
32
from mnn.layer import *
43

@@ -120,10 +119,19 @@ def state_dict(self):
120119
state_dict[path] = val
121120
return state_dict
122121

123-
def config(self):
124-
return json.dumps({
125-
'layers': len(self.layers)
126-
})
122+
def get_config(self):
    """Return a dict describing this container.

    The only entry is ``'layers'``, the number of layers held.
    """
    layer_count = len(self.layers)
    return {'layers': layer_count}
124+
125+
def load_weights(self, state_dict, config=None, verbose=False):
    """Hand each checkpoint entry to the layer that owns it.

    Every key of ``state_dict`` must look like
    ``"<layer index>.<layer-local path>"``. The leading index selects the
    entry of ``self.layers`` and the remainder is passed on as the key of a
    single-entry dict to that layer's ``_load_weights``.

    Args:
        state_dict: Mapping from indexed paths to values. It is not
            modified.
        config: Dict with a ``'layers'`` count. Required despite the
            ``None`` default.
        verbose: When true, print each path as it is loaded.

    Raises:
        AssertionError: if ``config`` is ``None`` or its layer count does
            not match ``len(self.layers)``.
        ValueError: if a key does not start with an integer index.
    """
    assert config is not None, 'a config is required to load weights'
    assert len(self.layers) == config['layers'], (
        f"checkpoint has {config['layers']} layers, "
        f"model has {len(self.layers)}"
    )
    for path, value in state_dict.items():
        if verbose:
            print('loading weights to:', path)
        # Split off only the leading layer index; the remainder, dots and
        # all, is the layer-local path.
        index_field, _, subpath = path.partition('.')
        layer_index = int(index_field)
        # Pass a fresh dict rather than rebinding the `state_dict` parameter
        # that is being iterated.
        self.layers[layer_index]._load_weights({subpath: value})
127135

128136

129137
if __name__ == '__main__':

mnn/tensor.py

+6
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def __mul__(self, x):
6363
else:
6464
return Tensor(self._data * x)
6565

66+
def __eq__(self, x):
    """Compare this tensor with another ``Tensor`` using ``==``.

    Returns:
        A ``Tensor`` wrapping ``self._data == x._data`` when ``x`` is a
        ``Tensor``; otherwise ``NotImplemented``.
    """
    if isinstance(x, Tensor):
        return Tensor(self._data == x._data)
    # NotImplemented is a sentinel value, not an exception: it must be
    # returned so Python can try the reflected comparison and then fall
    # back to its default. Raising it fails with an unrelated TypeError.
    return NotImplemented
71+
6672
def __rmul__(self, x):
6773
if isinstance(x, Tensor):
6874
return Tensor(self._data * x._data)

0 commit comments

Comments
 (0)