|
11 | 11 | import pickle
|
12 | 12 |
|
13 | 13 |
|
14 |
| -def train(epochs=10, dryrun=False, debug=False, batch_size=64, |
| 14 | +def train(epochs=10, batch_size=64, dryrun=False, debug=False, |
15 | 15 | save_file='data/mnist_model_ckpt.pkl'):
|
16 | 16 |
|
17 |
| - dataset = MNIST('./data/MNIST/mnn_test.pickle') |
| 17 | + dataset = MNIST('./data/MNIST/mnn_train.pickle') |
18 | 18 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
|
19 | 19 | shuffle=True, collate_fn=lambda batch: batch)
|
20 | 20 |
|
@@ -47,12 +47,41 @@ def train(epochs=10, dryrun=False, debug=False, batch_size=64,
|
47 | 47 |
|
48 | 48 | print('saving checkpoint ...')
|
49 | 49 | with open(save_file, 'wb') as fh:
|
50 |
| - save = net.state_dict(), net.config() |
| 50 | + save = net.state_dict(), net.get_config() |
51 | 51 | pickle.dump(save, fh)
|
52 | 52 |
|
53 | 53 |
|
54 |
| -def test(dryrun=False, debug=False): |
55 |
| - pass |
| 54 | +def test(checkpoint, batch_size=64): |
| 55 | + with open(checkpoint, 'rb') as fh: |
| 56 | + state_dict, config = pickle.load(fh) |
| 57 | + |
| 58 | + dataset = MNIST('./data/MNIST/mnn_test.pickle') |
| 59 | + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, |
| 60 | + shuffle=True, collate_fn=lambda batch: batch) |
| 61 | + |
| 62 | + net = SequentialLayers([ |
| 63 | + LinearLayer(28 * 28, 256), |
| 64 | + ReluLayer(), |
| 65 | + LinearLayer(256, 10), |
| 66 | + SoftmaxLayer() |
| 67 | + ]) |
| 68 | + net.load_weights(state_dict, config=config, verbose=True) |
| 69 | + |
| 70 | + correct_cnt, inference_cnt = 0, 0 |
| 71 | + for b, batch in enumerate(loader): |
| 72 | + images = Tensor([data for data, label in batch]) |
| 73 | + images = images.unsqueeze(-1) |
| 74 | + labels = Tensor([label for data, label in batch]) |
| 75 | + |
| 76 | + scores = net(images).squeeze(-1) |
| 77 | + |
| 78 | + preds = scores.argmax(-1) |
| 79 | + corrects = (preds == labels) |
| 80 | + correct_cnt += corrects.sum().item() |
| 81 | + inference_cnt += labels.shape[0] |
| 82 | + |
| 83 | + accuracy = correct_cnt / inference_cnt |
| 84 | + print(f'test accuracy: {accuracy:.3f}') |
56 | 85 |
|
57 | 86 |
|
58 | 87 | if __name__ == '__main__':
|
|
0 commit comments