Remove error when test dataloader used in test #1495

Merged
32 commits merged on Apr 16, 2020
Changes from 20 commits

Commits (32):
9262acc
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
b433577
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
ecc7d2a
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
4bbf9a7
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
11404ca
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
788cb01
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
8bf9b4d
fix lost model reference
williamFalcon Apr 15, 2020
5b57c54
remove error when test dataloader used in test
williamFalcon Apr 15, 2020
168c96c
fix lost model reference
williamFalcon Apr 15, 2020
1211b57
moved optimizer types
williamFalcon Apr 15, 2020
7eb08e6
moved optimizer types
williamFalcon Apr 15, 2020
27b435f
moved optimizer types
williamFalcon Apr 15, 2020
39b9cfb
moved optimizer types
williamFalcon Apr 15, 2020
77be73d
moved optimizer types
williamFalcon Apr 15, 2020
86f681c
moved optimizer types
williamFalcon Apr 15, 2020
a027eda
moved optimizer types
williamFalcon Apr 15, 2020
03c26af
moved optimizer types
williamFalcon Apr 15, 2020
9839cf3
added tests for warning
williamFalcon Apr 15, 2020
329f887
fix lost model reference
williamFalcon Apr 15, 2020
77b98e5
fix lost model reference
williamFalcon Apr 15, 2020
f51523e
added tests for warning
williamFalcon Apr 15, 2020
8aa5b8d
added tests for warning
williamFalcon Apr 15, 2020
5cba21d
refactoring
Borda Apr 15, 2020
5555b41
refactoring
Borda Apr 15, 2020
7dfcb8f
fix imports
Borda Apr 15, 2020
9748952
refactoring
Borda Apr 15, 2020
ca64314
fix imports
Borda Apr 15, 2020
9275762
refactoring
Borda Apr 15, 2020
686aa34
fix tests
Borda Apr 15, 2020
b9626de
fix mnist
Borda Apr 15, 2020
0460e23
flake8
Borda Apr 15, 2020
e2bb08d
review
Borda Apr 15, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed test for no test dataloader in .fit ([#1495](https://github.com/PyTorchLightning/pytorch-lightning/pull/1495))
Review comment (Member), suggested change:
Current:   - Removed test for no test dataloader in .fit ([#1495](https://github.com/PyTorchLightning/pytorch-lightning/pull/1495))
Suggested: - Removed test for no test dataloader in `.fit()` ([#1495](https://github.com/PyTorchLightning/pytorch-lightning/pull/1495))

- Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
- Removed wandb logger's `finalize` method ([#1193](https://github.com/PyTorchLightning/pytorch-lightning/pull/1193))
- Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))
27 changes: 23 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -944,6 +944,9 @@ def test(self, model: Optional[LightningModule] = None, test_dataloaders: Option
        else:
            self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

        # give proper warnings if user only passed in loader without hooks
        self.check_testing_model_configuration(model, test_dataloaders)

        if model is not None:
            self.model = model
            self.fit(model)
@@ -1012,10 +1015,26 @@ def check_model_configuration(self, model: LightningModule):
                        'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
                        ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
                    )
        else:
            if self.is_overriden('test_step', model):
                raise MisconfigurationException('You have defined `test_step()`,'
                                                ' but have not passed in a `test_dataloader()`.')

    def check_testing_model_configuration(self, model: LightningModule, test_dataloader: DataLoader):

        has_test_step = self.is_overriden('test_step', model)
        has_test_epoch_end = self.is_overriden('test_epoch_end', model)
        gave_test_loader = test_dataloader is not None

        if gave_test_loader and not has_test_step:
            raise MisconfigurationException('You passed in a `test_dataloader` but did not implement'
                                            ' `test_step()`.')

        if has_test_step and not gave_test_loader:
            raise MisconfigurationException('You defined `test_step()` but did not implement'
                                            ' `test_dataloader()` nor passed `test_dataloaders` to `.fit()`.')

        if has_test_step and gave_test_loader and not has_test_epoch_end:
            rank_zero_warn(
                'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
                ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
            )


class _PatchDataLoader(object):
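For context, here is a hedged end-to-end sketch of the behaviour this change targets, written against the 0.7.x-era API; the module, class names, and parameter values below are illustrative assumptions, not code from this PR. The model implements `test_step()` but no `test_dataloader()` hook, and the loader is passed straight to `.test()`. The removed check used to reject this; with `check_testing_model_configuration` it is accepted, and only a RuntimeWarning is raised if `test_epoch_end()` is missing.

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class SketchModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.cross_entropy(self(x), y)}

    def test_step(self, batch, batch_idx):
        x, y = batch
        return {'test_loss': F.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=16)


model = SketchModel()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)

# no test_dataloader() on the model; the loader is handed to .test() directly
test_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=16)
trainer.test(model, test_dataloaders=test_loader)  # accepted; warns that test_epoch_end() is missing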
1 change: 1 addition & 0 deletions tests/base/__init__.py
@@ -3,6 +3,7 @@
import torch

from tests.base.models import TestModelBase, DictHparamsModel
from tests.base.template_test_model import TemplateTestModel
from tests.base.mixins import (
    LightEmptyTestStep,
    LightValidationStepMixin,
59 changes: 59 additions & 0 deletions tests/base/config_optimizers_variations.py
@@ -0,0 +1,59 @@
from torch import optim


class ConfigureOptimizersVariationsMixin:
    def configure_optimizers(self):
        """
        return whatever optimizers we want here.
        :return: list of optimizers
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def configure_optimizers_empty(self):
        return None

    def configure_optimizers_lbfgs(self):
        """
        return whatever optimizers we want here.
        :return: list of optimizers
        """
        optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def configure_optimizers_multiple_optimizers(self):
        """
        return whatever optimizers we want here.
        :return: list of optimizers
        """
        # try no scheduler for this model (testing purposes)
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer1, optimizer2

    def configure_optimizers_single_scheduler(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
        return [optimizer], [lr_scheduler]

    def configure_optimizers_multiple_schedulers(self):
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

        return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

    def configure_optimizers_mixed_scheduling(self):
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

        return [optimizer1, optimizer2], \
            [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]

    def configure_optimizers_reduce_lr_on_plateau(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [lr_scheduler]
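A small sketch of how a test might exercise one of these variations; the usage pattern (rebinding `configure_optimizers` to the wanted variant) is an assumption about the calling tests, which this diff does not show, and `_TinyModel` is a made-up helper.

from argparse import Namespace

import torch.nn as nn

from tests.base.config_optimizers_variations import ConfigureOptimizersVariationsMixin


class _TinyModel(ConfigureOptimizersVariationsMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.hparams = Namespace(learning_rate=0.001)  # only field the mixin reads
        self.layer = nn.Linear(4, 2)                   # gives the optimizers something to optimize


model = _TinyModel()
# point the hook at the variation under test; Lightning would then call this one
model.configure_optimizers = model.configure_optimizers_multiple_schedulers
optimizers, schedulers = model.configure_optimizers()
assert len(optimizers) == 2 and len(schedulers) == 2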
80 changes: 80 additions & 0 deletions tests/base/template_test_model.py
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from tests.base.datasets import TestingMNIST
from pytorch_lightning.core.lightning import LightningModule
from tests.base.training_step_variations import TrainingStepVariationsMixin
from tests.base.test_step_variations import TestStepVariationsMixin
from tests.base.validation_step_variations import ValidationStepVariationsMixin
from tests.base.test_epoch_end_variations import TestEpochEndVariationsMixin
from tests.base.config_optimizers_variations import ConfigureOptimizersVariationsMixin
from tests.base.val_dataloader_variations import ValDataloaderVariationsMixin
from tests.base.train_dataloader_variations import TrainDataloaderVariationsMixin
from tests.base.test_dataloader_variations import TestDataloaderVariationsMixin
from tests.base.validation_epoch_end_variations import ValidationEpochEndVariationsMixin
from tests.base.template_test_model_utils import TemplateTestModelUtilsMixin


class TemplateTestModel(
    TrainingStepVariationsMixin,
    ValidationStepVariationsMixin,
    ValidationEpochEndVariationsMixin,
    TestStepVariationsMixin,
    TestEpochEndVariationsMixin,
    TrainDataloaderVariationsMixin,
    ValDataloaderVariationsMixin,
    TestDataloaderVariationsMixin,
    ConfigureOptimizersVariationsMixin,
    TemplateTestModelUtilsMixin,
    LightningModule
):
    """
    This template houses all combinations of model configurations we want to test
    """
    def __init__(self, hparams):
        """Pass in parsed HyperOptArgumentParser to the model."""
        # init superclass
        super().__init__()
        self.hparams = hparams

        # if you specify an example input, the summary will show input/output for each layer
        self.example_input_array = torch.rand(5, 28 * 28)

        # build model
        self.__build_model()

    def __build_model(self):
        """
        Simple model for testing
        :return:
        """
        self.c_d1 = nn.Linear(
            in_features=self.hparams.in_features,
            out_features=self.hparams.hidden_dim
        )
        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

        self.c_d2 = nn.Linear(
            in_features=self.hparams.hidden_dim,
            out_features=self.hparams.out_features
        )

    def forward(self, x):
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)

        x = self.c_d2(x)
        logits = F.log_softmax(x, dim=1)

        return logits

    def loss(self, labels, logits):
        nll = F.nll_loss(logits, labels)
        return nll

    def prepare_data(self):
        _ = TestingMNIST(root=self.hparams.data_root, train=True, download=True)
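A quick sketch of constructing the template model; the hyperparameter names come from the code above, while the values (and the use of an `argparse.Namespace`) are illustrative assumptions rather than part of the diff.

from argparse import Namespace

from tests.base import TemplateTestModel

hparams = Namespace(
    in_features=28 * 28,   # matches example_input_array above
    hidden_dim=1000,
    drop_prob=0.2,
    out_features=10,
    learning_rate=0.001,
    batch_size=32,
    data_root='datasets',
)
model = TemplateTestModel(hparams)
logits = model(model.example_input_array)
assert logits.shape == (5, hparams.out_features)  # log-softmax scores, one row per example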
23 changes: 23 additions & 0 deletions tests/base/template_test_model_utils.py
@@ -0,0 +1,23 @@
from torch.utils.data import DataLoader
from tests.base.datasets import TestingMNIST


class TemplateTestModelUtilsMixin:

    def dataloader(self, train):
        dataset = TestingMNIST(root=self.hparams.data_root, train=train, download=False)

        loader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True
        )

        return loader

    def get_output_metric(self, output, name):
        if isinstance(output, dict):
            val = output[name]
        else:  # if it is two levels deep -> per dataloader and per batch
            val = sum(out[name] for out in output) / len(output)
        return val
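To illustrate the two shapes `get_output_metric` accepts (this example is not part of the diff): a plain dict from a single test dataloader, and a list of per-dataloader dicts that gets averaged.

from tests.base.template_test_model_utils import TemplateTestModelUtilsMixin

utils = TemplateTestModelUtilsMixin()

single_loader_output = {'test_loss': 0.25}
multi_loader_output = [{'test_loss': 0.25}, {'test_loss': 0.75}]  # one dict per dataloader

assert utils.get_output_metric(single_loader_output, 'test_loss') == 0.25
assert utils.get_output_metric(multi_loader_output, 'test_loss') == 0.5  # mean across dataloaders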
11 changes: 11 additions & 0 deletions tests/base/test_dataloader_variations.py
@@ -0,0 +1,11 @@
import tests.base.model_utils as mutils


class TestDataloaderVariationsMixin:

    def test_dataloader(self):
        return mutils.dataloader(
            train=False,
            data_root=self.hparams.data_root,
            batch_size=self.hparams.batch_size,
        )
36 changes: 36 additions & 0 deletions tests/base/test_epoch_end_variations.py
@@ -0,0 +1,36 @@
import torch


class TestEpochEndVariationsMixin:
    def test_epoch_end(self, outputs):
        """
        Called at the end of testing to aggregate outputs
        :param outputs: list of individual outputs of each test step
        :return:
        """
        # if returned a scalar from test_step, outputs is a list of tensor scalars
        # we return just the average in this case (if we want)
        # return torch.stack(outputs).mean()
        test_loss_mean = 0
        test_acc_mean = 0
        for output in outputs:
            test_loss = self.get_output_metric(output, 'test_loss')

            # reduce manually when using dp
            if self.trainer.use_dp:
                test_loss = torch.mean(test_loss)
            test_loss_mean += test_loss

            # reduce manually when using dp
            test_acc = self.get_output_metric(output, 'test_acc')
            if self.trainer.use_dp:
                test_acc = torch.mean(test_acc)

            test_acc_mean += test_acc

        test_loss_mean /= len(outputs)
        test_acc_mean /= len(outputs)

        metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
        result = {'progress_bar': metrics_dict, 'log': metrics_dict}
        return result
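A standalone sketch of what this aggregation produces (an illustration, not part of the diff); the stub classes below are assumptions that fake the trainer flag and metric lookup the full model normally provides.

import torch

from tests.base.test_epoch_end_variations import TestEpochEndVariationsMixin


class _TrainerStub:
    use_dp = False


class _FakeModel(TestEpochEndVariationsMixin):
    trainer = _TrainerStub()

    def get_output_metric(self, output, name):
        return output[name]


outputs = [
    {'test_loss': torch.tensor(0.9), 'test_acc': torch.tensor(0.5)},
    {'test_loss': torch.tensor(0.7), 'test_acc': torch.tensor(0.7)},
]
result = _FakeModel().test_epoch_end(outputs)
# result['log'] is approximately {'test_loss': 0.8, 'test_acc': 0.6}, mirrored in result['progress_bar']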
50 changes: 50 additions & 0 deletions tests/base/test_step_variations.py
@@ -0,0 +1,50 @@
from collections import OrderedDict
import torch


class TestStepVariationsMixin:
    """
    Houses all variations of test steps
    """
    def test_step(self, batch, batch_idx, dataloader_idx, **kwargs):
        """
        Default, baseline test_step
        :param batch:
        :return:
        """
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_test = self.loss(y, y_hat)

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        test_acc = torch.tensor(test_acc)

        test_acc = test_acc.type_as(x)

        # alternate possible outputs to test
        if batch_idx % 1 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
            })
            return output
        if batch_idx % 2 == 0:
            return test_acc

        if batch_idx % 3 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
                'test_dic': {'test_loss_a': loss_test}
            })
            return output
        if batch_idx % 5 == 0:
            output = OrderedDict({
                f'test_loss_{dataloader_idx}': loss_test,
                f'test_acc_{dataloader_idx}': test_acc,
            })
            return output
11 changes: 11 additions & 0 deletions tests/base/train_dataloader_variations.py
@@ -0,0 +1,11 @@
import tests.base.model_utils as mutils


class TrainDataloaderVariationsMixin:

    def train_dataloader(self):
        return mutils.dataloader(
            train=True,
            data_root=self.hparams.data_root,
            batch_size=self.hparams.batch_size,
        )
29 changes: 29 additions & 0 deletions tests/base/training_step_variations.py
@@ -0,0 +1,29 @@
from collections import OrderedDict


class TrainingStepVariationsMixin:
    """
    Houses all variations of training steps
    """
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Lightning calls this inside the training loop"""
        # forward pass
        x, y = batch
        x = x.view(x.size(0), -1)

        y_hat = self(x)

        # calculate loss
        loss_val = self.loss(y, y_hat)

        # alternate possible outputs to test
        if self.trainer.batch_idx % 1 == 0:
            output = OrderedDict({
                'loss': loss_val,
                'progress_bar': {'some_val': loss_val * loss_val},
                'log': {'train_some_val': loss_val * loss_val},
            })
            return output

        if self.trainer.batch_idx % 2 == 0:
            return loss_val
11 changes: 11 additions & 0 deletions tests/base/val_dataloader_variations.py
@@ -0,0 +1,11 @@
import tests.base.model_utils as mutils


class ValDataloaderVariationsMixin:

    def val_dataloader(self):
        return mutils.dataloader(
            train=False,
            data_root=self.hparams.data_root,
            batch_size=self.hparams.batch_size,
        )