Skip to content

Commit 3d65e2e

Browse files
williamFalconBorda
authored and
akarnachev
committed
Parity test (Lightning-AI#1284)
* adding test * adding test * added base parity model * added base parity model * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * move parity to benchmark * formatting * fixed gradient acc sched * move parity to benchmark * formatting * fixed gradient acc sched * skip for CPU * call last Co-authored-by: J. Borovec <[email protected]>
1 parent 097f8fe commit 3d65e2e

File tree

4 files changed

+153
-1
lines changed

4 files changed

+153
-1
lines changed

.drone.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ steps:
2222
- pip install -r ./tests/requirements.txt --user
2323
- pip list
2424
- python -c "import torch ; print(' & '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) if torch.cuda.is_available() else 'only CPU')"
25-
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests -v --doctest-modules # --flake8
25+
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests benchmarks -v --doctest-modules # --flake8
2626
- coverage report
2727
- codecov --token $CODECOV_TOKEN # --pr $DRONE_PULL_REQUEST --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG
2828
- python tests/collect_env_details.py

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
1112
- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
1213
- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
1314
- Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))

benchmarks/__init__.py

Whitespace-only changes.

benchmarks/test_trainer_parity.py

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import os
2+
import time
3+
4+
import numpy as np
5+
import pytest
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
from torch.utils.data import DataLoader
10+
from torchvision import transforms
11+
from torchvision.datasets import MNIST
12+
13+
from pytorch_lightning import Trainer, LightningModule
14+
15+
16+
class ParityMNIST(LightningModule):
17+
18+
def __init__(self):
19+
super(ParityMNIST, self).__init__()
20+
self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
21+
self.c_d1_bn = nn.BatchNorm1d(128)
22+
self.c_d1_drop = nn.Dropout(0.3)
23+
self.c_d2 = nn.Linear(in_features=128, out_features=10)
24+
25+
def forward(self, x):
26+
x = x.view(x.size(0), -1)
27+
x = self.c_d1(x)
28+
x = torch.tanh(x)
29+
x = self.c_d1_bn(x)
30+
x = self.c_d1_drop(x)
31+
x = self.c_d2(x)
32+
return x
33+
34+
def training_step(self, batch, batch_nb):
35+
x, y = batch
36+
y_hat = self(x)
37+
loss = F.cross_entropy(y_hat, y)
38+
return {'loss': loss}
39+
40+
def configure_optimizers(self):
41+
return torch.optim.Adam(self.parameters(), lr=0.02)
42+
43+
def train_dataloader(self):
44+
return DataLoader(MNIST(os.getcwd(), train=True, download=True,
45+
transform=transforms.ToTensor()), batch_size=32)
46+
47+
48+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
49+
def test_pytorch_parity(tmpdir):
50+
"""
51+
Verify that the same pytorch and lightning models achieve the same results
52+
:param tmpdir:
53+
:return:
54+
"""
55+
num_epochs = 2
56+
num_rums = 3
57+
lightning_outs, pl_times = lightning_loop(ParityMNIST, num_rums, num_epochs)
58+
manual_outs, pt_times = vanilla_loop(ParityMNIST, num_rums, num_epochs)
59+
60+
# make sure the losses match exactly to 5 decimal places
61+
for pl_out, pt_out in zip(lightning_outs, manual_outs):
62+
np.testing.assert_almost_equal(pl_out, pt_out, 5)
63+
64+
65+
def set_seed(seed):
66+
np.random.seed(seed)
67+
torch.manual_seed(seed)
68+
if torch.cuda.is_available():
69+
torch.cuda.manual_seed(seed)
70+
71+
72+
def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
73+
"""
74+
Returns an array with the last loss from each epoch for each run
75+
"""
76+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
77+
errors = []
78+
times = []
79+
80+
for i in range(num_runs):
81+
time_start = time.perf_counter()
82+
83+
# set seed
84+
seed = i
85+
set_seed(seed)
86+
87+
# init model parts
88+
model = MODEL()
89+
dl = model.train_dataloader()
90+
optimizer = model.configure_optimizers()
91+
92+
# model to GPU
93+
model = model.to(device)
94+
95+
epoch_losses = []
96+
for epoch in range(num_epochs):
97+
98+
# run through full training set
99+
for j, batch in enumerate(dl):
100+
x, y = batch
101+
x = x.cuda(0)
102+
y = y.cuda(0)
103+
batch = (x, y)
104+
105+
loss_dict = model.training_step(batch, j)
106+
loss = loss_dict['loss']
107+
loss.backward()
108+
optimizer.step()
109+
optimizer.zero_grad()
110+
111+
# track last epoch loss
112+
epoch_losses.append(loss.item())
113+
114+
time_end = time.perf_counter()
115+
times.append(time_end - time_start)
116+
117+
errors.append(epoch_losses[-1])
118+
119+
return errors, times
120+
121+
122+
def lightning_loop(MODEL, num_runs=10, num_epochs=10):
123+
errors = []
124+
times = []
125+
126+
for i in range(num_runs):
127+
time_start = time.perf_counter()
128+
129+
# set seed
130+
seed = i
131+
set_seed(seed)
132+
133+
# init model parts
134+
model = MODEL()
135+
trainer = Trainer(
136+
max_epochs=num_epochs,
137+
show_progress_bar=False,
138+
weights_summary=None,
139+
gpus=1,
140+
early_stop_callback=False,
141+
checkpoint_callback=False
142+
)
143+
trainer.fit(model)
144+
145+
final_loss = trainer.running_loss.last().item()
146+
errors.append(final_loss)
147+
148+
time_end = time.perf_counter()
149+
times.append(time_end - time_start)
150+
151+
return errors, times

0 commit comments

Comments
 (0)