Skip to content

Commit bd45203

Browse files
committed
🚧 wip
1 parent 71a8c99 commit bd45203

File tree

3 files changed

+118
-5
lines changed

3 files changed

+118
-5
lines changed

pytorch_lightning/trainer/evaluation_loop.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ class TrainerEvaluationLoopMixin(ABC):
180180
tpu_id: int
181181
verbose_test: bool
182182
running_sanity_check: bool
183-
testing: bool
184183
amp_backend: AMPType
185184

186185
# Callback system
@@ -372,9 +371,10 @@ def _evaluate(
372371

373372
# track outputs for collation
374373
if output is not None:
375-
do_write_preds = self.testing and isinstance(output, EvalResult) and not self.running_sanity_check
376-
# Add predictions to our prediction collection if they are found in outputs
377-
if do_write_preds:
374+
375+
# Add step predictions to prediction collection to write later
376+
do_write_predictions = is_result_obj and test_mode
377+
if do_write_predictions:
378378
predictions.add(output.pop('predictions', None))
379379

380380
dl_outputs.append(output)

tests/base/model_test_steps.py

+51
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import random
12
from abc import ABC
23
from collections import OrderedDict
34

45
import torch
56

7+
from pytorch_lightning import EvalResult
8+
69

710
class TestStepVariations(ABC):
811
"""
@@ -91,3 +94,51 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw
9194

9295
def test_step__empty(self, batch, batch_idx, *args, **kwargs):
9396
return {}
97+
98+
99+
def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
100+
"""Lightning calls this inside the training loop"""
101+
"""
102+
Default, baseline test_step
103+
:param batch:
104+
:return:
105+
"""
106+
x, y = batch
107+
x = x.view(x.size(0), -1)
108+
y_hat = self(x)
109+
110+
loss_test = self.loss(y, y_hat)
111+
112+
# acc
113+
labels_hat = torch.argmax(y_hat, dim=1)
114+
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
115+
test_acc = torch.tensor(test_acc)
116+
117+
test_acc = test_acc.type_as(x)
118+
119+
# Do regular EvalResult Logging
120+
result = EvalResult(checkpoint_on=loss_test)
121+
result.log('test_loss', loss_test)
122+
result.log('test_acc', test_acc)
123+
124+
#lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
125+
# int_outputs = [random.randint(500, 1000) for i in range(batch_size)]
126+
#nested_lst = [[x] for x in int_outputs]
127+
#lst_of_dicts = [{k: v} for k, v in zip(lst_of_str, int_outputs)]
128+
129+
# This is passed in from pytest via parameterization
130+
option = getattr(self, 'test_option', 0)
131+
132+
lazy_ids = torch.arange(batch_idx * self.batch_size, (batch_idx + 1) * x.size(0))
133+
134+
# Base
135+
if option == 0:
136+
result.write('idxs', lazy_ids)
137+
result.write('preds', labels_hat)
138+
139+
# Check mismatching tensor len
140+
elif option == 1:
141+
result.write('idxs', torch.cat((lazy_ids, lazy_ids)))
142+
result.write('preds', labels_hat)
143+
144+
return result

tests/core/test_results.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
import sys
2+
from pathlib import Path
3+
14
import pytest
25
import torch
36
import torch.distributed as dist
47
import torch.multiprocessing as mp
8+
from pytorch_lightning import Trainer
59
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
610
import tests.base.develop_utils as tutils
7-
import sys
11+
12+
from tests.base import EvalModelTemplate
13+
from tests.base.datamodules import TrialMNISTDataModule
814

915

1016
def _setup_ddp(rank, worldsize):
@@ -35,3 +41,59 @@ def test_result_reduce_ddp(result_cls):
3541

3642
worldsize = 2
3743
mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)
44+
45+
46+
@pytest.mark.parametrize(
47+
"option,do_train",
48+
[
49+
pytest.param(
50+
0, True, id='full_loop'
51+
),
52+
pytest.param(
53+
0, False, id='test_only'
54+
),
55+
pytest.param(
56+
1, False, id='test_only_mismatching_tensor', marks=pytest.mark.xfail(raises=ValueError, match="Mism.*")
57+
),
58+
]
59+
)
60+
def test_result_obj_predictions(tmpdir, option, do_train):
61+
tutils.reset_seed()
62+
63+
dm = TrialMNISTDataModule(tmpdir)
64+
65+
model = EvalModelTemplate()
66+
model.test_option = option
67+
model.prediction_file = Path('predictions.pt')
68+
model.test_step = model.test_step_result_preds
69+
model.test_step_end = None
70+
model.test_epoch_end = None
71+
model.test_end = None
72+
73+
if model.prediction_file.exists():
74+
model.prediction_file.unlink()
75+
76+
trainer = Trainer(
77+
default_root_dir=tmpdir,
78+
max_epochs=3,
79+
weights_summary=None,
80+
deterministic=True,
81+
)
82+
83+
# Prediction file shouldn't exist yet because we haven't done anything
84+
assert not model.prediction_file.exists()
85+
86+
if do_train:
87+
result = trainer.fit(model, dm)
88+
assert result == 1
89+
result = trainer.test(datamodule=dm)
90+
result = result[0]
91+
assert result['test_loss'] < 0.6
92+
assert result['test_acc'] > 0.8
93+
else:
94+
result = trainer.test(model, datamodule=dm)
95+
96+
# check prediction file now exists and is of expected length
97+
assert model.prediction_file.exists()
98+
predictions = torch.load(model.prediction_file)
99+
assert len(predictions) == len(dm.mnist_test)

0 commit comments

Comments
 (0)