Remove error when test dataloader used in test #1495
Merged
Commits (32)

9262acc williamFalcon: remove error when test dataloader used in test
b433577 williamFalcon: remove error when test dataloader used in test
ecc7d2a williamFalcon: remove error when test dataloader used in test
4bbf9a7 williamFalcon: remove error when test dataloader used in test
11404ca williamFalcon: remove error when test dataloader used in test
788cb01 williamFalcon: remove error when test dataloader used in test
8bf9b4d williamFalcon: fix lost model reference
5b57c54 williamFalcon: remove error when test dataloader used in test
168c96c williamFalcon: fix lost model reference
1211b57 williamFalcon: moved optimizer types
7eb08e6 williamFalcon: moved optimizer types
27b435f williamFalcon: moved optimizer types
39b9cfb williamFalcon: moved optimizer types
77be73d williamFalcon: moved optimizer types
86f681c williamFalcon: moved optimizer types
a027eda williamFalcon: moved optimizer types
03c26af williamFalcon: moved optimizer types
9839cf3 williamFalcon: added tests for warning
329f887 williamFalcon: fix lost model reference
77b98e5 williamFalcon: fix lost model reference
f51523e williamFalcon: added tests for warning
8aa5b8d williamFalcon: added tests for warning
5cba21d Borda: refactoring
5555b41 Borda: refactoring
7dfcb8f Borda: fix imports
9748952 Borda: refactoring
ca64314 Borda: fix imports
9275762 Borda: refactoring
686aa34 Borda: fix tests
b9626de Borda: fix mnist
0460e23 Borda: flake8
e2bb08d Borda: review
tests/base/eval_model_optimizers.py (new file)
@@ -0,0 +1,61 @@

from abc import ABC

from torch import optim


class ConfigureOptimizersPool(ABC):
    def configure_optimizers(self):
        """
        Return whatever optimizers we want here.
        :return: a single optimizer
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def configure_optimizers_empty(self):
        return None

    def configure_optimizers_lbfgs(self):
        """
        Return whatever optimizers we want here.
        :return: a single LBFGS optimizer
        """
        optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def configure_optimizers_multiple_optimizers(self):
        """
        Return whatever optimizers we want here.
        :return: tuple of two optimizers
        """
        # try no scheduler for this model (testing purposes)
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer1, optimizer2

    def configure_optimizers_single_scheduler(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
        return [optimizer], [lr_scheduler]

    def configure_optimizers_multiple_schedulers(self):
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

        return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

    def configure_optimizers_mixed_scheduling(self):
        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

        return [optimizer1, optimizer2], \
            [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]

    def configure_optimizers_reduce_lr_on_plateau(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [lr_scheduler]
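The methods above return optimizers and schedulers in several shapes: a bare optimizer, `None`, a tuple of optimizers, a pair of lists, or a scheduler dict with an `'interval'` key. As a simplified illustration only (this is a hypothetical `init_optimizers` helper, not Lightning's actual implementation), a trainer can normalize those shapes into a uniform pair of lists roughly like this:

```python
def init_optimizers(optim_conf):
    """Normalize configure_optimizers() return shapes into ([optimizers], [schedulers]).

    Sketch under simplifying assumptions: dict-based scheduler specs like
    {'scheduler': ..., 'interval': 'step'} are kept as-is in the scheduler list.
    """
    if optim_conf is None:
        # the "empty" variation: nothing to optimize
        return [], []
    if not isinstance(optim_conf, (list, tuple)):
        # a single bare optimizer
        return [optim_conf], []
    if len(optim_conf) == 2 and isinstance(optim_conf[0], list):
        # the ([optimizers], [schedulers]) pair form
        optimizers, schedulers = optim_conf
        return list(optimizers), list(schedulers)
    # a plain tuple/list of optimizers, no schedulers
    return list(optim_conf), []


# strings stand in for real optimizer/scheduler objects
print(init_optimizers('adam'))                 # (['adam'], [])
print(init_optimizers(('adam1', 'adam2')))     # (['adam1', 'adam2'], [])
print(init_optimizers((['adam'], ['step'])))   # (['adam'], ['step'])
```

The single-dispatch shape checks mirror why the test pool above exercises every return form separately: each branch is a distinct code path a trainer must handle.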
@@ -0,0 +1,80 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from tests.base.datasets import TrialMNIST
from pytorch_lightning.core.lightning import LightningModule
from tests.base.eval_model_optimizers import ConfigureOptimizersPool
from tests.base.eval_model_test_dataloaders import TestDataloaderVariations
from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
from tests.base.eval_model_test_steps import TestStepVariations
from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
from tests.base.eval_model_train_steps import TrainingStepVariations
from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
from tests.base.eval_model_valid_steps import ValidationStepVariations
from tests.base.eval_model_utils import ModelTemplateUtils


class EvalModelTemplate(
    ModelTemplateUtils,
    TrainingStepVariations,
    ValidationStepVariations,
    ValidationEpochEndVariations,
    TestStepVariations,
    TestEpochEndVariations,
    TrainDataloaderVariations,
    ValDataloaderVariations,
    TestDataloaderVariations,
    ConfigureOptimizersPool,
    LightningModule
):
    """
    This template houses all combinations of model configurations we want to test
    """
    def __init__(self, hparams):
        """Pass in parsed HyperOptArgumentParser to the model."""
        # init superclass
        super().__init__()
        self.hparams = hparams

        # if you specify an example input, the summary will show input/output for each layer
        self.example_input_array = torch.rand(5, 28 * 28)

        # build model
        self.__build_model()

    def __build_model(self):
        """
        Simple model for testing
        :return:
        """
        self.c_d1 = nn.Linear(
            in_features=self.hparams.in_features,
            out_features=self.hparams.hidden_dim
        )
        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

        self.c_d2 = nn.Linear(
            in_features=self.hparams.hidden_dim,
            out_features=self.hparams.out_features
        )

    def forward(self, x):
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)

        x = self.c_d2(x)
        logits = F.log_softmax(x, dim=1)

        return logits

    def loss(self, labels, logits):
        nll = F.nll_loss(logits, labels)
        return nll

    def prepare_data(self):
        _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True)
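EvalModelTemplate composes its behavior from mixins: each mixin contributes one hook (a training step, a validation step, an epoch end, a dataloader), and the template inherits them all through multiple inheritance. A minimal torch-free sketch of that composition pattern, with hypothetical stub classes and toy metrics standing in for the real hooks:

```python
class TrainingStepMixin:
    def training_step(self, batch):
        # stub: treat the mean of the batch as the loss
        return {'loss': sum(batch) / len(batch)}


class ValidationStepMixin:
    def validation_step(self, batch):
        # stub: treat the minimum of the batch as the validation loss
        return {'val_loss': min(batch)}


class TinyTemplate(TrainingStepMixin, ValidationStepMixin):
    """Composes hook variations via multiple inheritance,
    the same way EvalModelTemplate composes its ten base classes."""


model = TinyTemplate()
print(model.training_step([1, 2, 3]))    # {'loss': 2.0}
print(model.validation_step([1, 2, 3]))  # {'val_loss': 1}
```

The design payoff is the same as in the PR: each hook variation lives in one small file, and a test can pick any combination by binding a different variation method onto the template.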
tests/base/eval_model_test_dataloaders.py (new file)
@@ -0,0 +1,11 @@

from abc import ABC, abstractmethod


class TestDataloaderVariations(ABC):

    @abstractmethod
    def dataloader(self, train: bool):
        """placeholder"""

    def test_dataloader(self):
        return self.dataloader(train=False)
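The mixin above leaves `dataloader()` abstract, so any concrete model must supply it; `test_dataloader()` then simply forwards with `train=False`. A small self-contained illustration of that contract (hypothetical class names, with plain lists standing in for real torch DataLoaders):

```python
from abc import ABC, abstractmethod


class DataloaderMixin(ABC):
    @abstractmethod
    def dataloader(self, train: bool):
        """concrete models must implement this hook"""

    def test_dataloader(self):
        # the mixin only decides which split to ask for
        return self.dataloader(train=False)


class DummyModel(DataloaderMixin):
    def dataloader(self, train: bool):
        # a plain list stands in for a torch DataLoader here
        return ['train-batch'] if train else ['test-batch']


print(DummyModel().test_dataloader())  # ['test-batch']
```

Note that instantiating `DataloaderMixin` directly raises `TypeError`, which is exactly the guarantee the abstract method buys: a model cannot claim to provide a test dataloader without implementing the underlying hook.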
tests/base/eval_model_test_epoch_ends.py (new file)
@@ -0,0 +1,39 @@

from abc import ABC

import torch


class TestEpochEndVariations(ABC):

    def test_epoch_end(self, outputs):
        """
        Called at the end of the test loop to aggregate outputs
        :param outputs: list of individual outputs of each test step
        :return:
        """
        # if returned a scalar from test_step, outputs is a list of tensor scalars
        # we return just the average in this case (if we want)
        # return torch.stack(outputs).mean()
        test_loss_mean = 0
        test_acc_mean = 0
        for output in outputs:
            test_loss = self.get_output_metric(output, 'test_loss')

            # reduce manually when using dp
            if self.trainer.use_dp:
                test_loss = torch.mean(test_loss)
            test_loss_mean += test_loss

            # reduce manually when using dp
            test_acc = self.get_output_metric(output, 'test_acc')
            if self.trainer.use_dp:
                test_acc = torch.mean(test_acc)

            test_acc_mean += test_acc

        test_loss_mean /= len(outputs)
        test_acc_mean /= len(outputs)

        metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
        result = {'progress_bar': metrics_dict, 'log': metrics_dict}
        return result
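Stripped of the torch tensors and the DP manual reduction, the aggregation `test_epoch_end` performs is just an average of each metric over the per-step outputs. A torch-free sketch (a hypothetical helper, with plain dicts of floats standing in for step outputs):

```python
def aggregate_test_outputs(outputs):
    """Average each metric across all test-step outputs, mirroring the
    shape of the result dict test_epoch_end returns."""
    n = len(outputs)
    test_loss_mean = sum(o['test_loss'] for o in outputs) / n
    test_acc_mean = sum(o['test_acc'] for o in outputs) / n
    metrics_dict = {'test_loss': test_loss_mean, 'test_acc': test_acc_mean}
    return {'progress_bar': metrics_dict, 'log': metrics_dict}


result = aggregate_test_outputs([
    {'test_loss': 0.4, 'test_acc': 0.8},
    {'test_loss': 0.2, 'test_acc': 0.9},
])
print(round(result['log']['test_loss'], 4))  # 0.3
print(round(result['log']['test_acc'], 4))   # 0.85
```

Returning the same dict under both `'progress_bar'` and `'log'` is what makes the averaged metrics show up in the progress bar and in the logger in one pass.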