Skip to content

Commit 768579d

Browse files
justusschockBorda
authored andcommitted
Rework of Sklearn Metrics (#1327)
* Create utils.py * Create __init__.py * redo sklearn metrics * add some more metrics * add sklearn metrics * Create __init__.py * redo sklearn metrics * New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <[email protected]> * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <[email protected]> * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec <[email protected]> * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec <[email protected]> * add sklearn metrics * start adding sklearn tests * fix typo * return x and y only for curves * fix typo * add missing tests for sklearn funcs * imports * __all__ * imports * fix sklearn arguments * fix imports * update requirements * Update CHANGELOG.md * Update test_sklearn_metrics.py * formatting * formatting * format * fix all warnings and formatting problems * Update environment.yml * Update requirements-extra.txt * Update environment.yml * Update requirements-extra.txt * fix all warnings and formatting problems * Update CHANGELOG.md * docs * inherit * docs inherit. * docs * Apply suggestions from code review Co-authored-by: Nicki Skafte <[email protected]> * docs * req * min * Apply suggestions from code review Co-authored-by: Tullie Murrell <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Tullie Murrell <[email protected]> (cherry picked from commit bd49b07)
1 parent 48113f9 commit 768579d

17 files changed

+983
-28
lines changed

.circleci/config.yml

+7-3
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,12 @@ references:
6464
name: Make Documentation
6565
command: |
6666
# First run the same pipeline as Read-The-Docs
67-
sudo apt-get update && sudo apt-get install -y cmake
68-
sudo pip install -r docs/requirements.txt
67+
# apt-get update && apt-get install -y cmake
68+
# using: https://hub.docker.com/r/readthedocs/build
69+
# we need to use py3.7 ot higher becase of an issue with metaclass inheritence
70+
pyenv global 3.7.3
71+
python --version
72+
pip install -r docs/requirements.txt
6973
cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"
7074
7175
test_docs: &test_docs
@@ -81,7 +85,7 @@ jobs:
8185

8286
Build-Docs:
8387
docker:
84-
- image: circleci/python:3.7
88+
- image: readthedocs/build:latest
8589
steps:
8690
- checkout
8791
- *make_docs

.github/workflows/ci-testing.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ jobs:
6868
- name: Set min. dependencies
6969
if: matrix.requires == 'minimal'
7070
run: |
71-
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
72-
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
73-
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
71+
python -c "req = open('requirements.txt').read().replace('>=', '==') ; open('requirements.txt', 'w').write(req)"
72+
python -c "req = open('requirements-extra.txt').read().replace('>=', '==') ; open('requirements-extra.txt', 'w').write(req)"
73+
python -c "req = open('tests/requirements-devel.txt').read().replace('>=', '==') ; open('tests/requirements-devel.txt', 'w').write(req)"
7474
7575
# Note: This uses an internal pip API and may not always work
7676
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow

CHANGELOG.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7-
87
## [unreleased] - YYYY-MM-DD
98

109
### Added
@@ -23,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2322
### Added
2423

2524
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
26-
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
25+
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
26+
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
2727
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
2828
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
2929
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)

docs/source/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
'sphinx.ext.linkcode',
9191
'sphinx.ext.autosummary',
9292
'sphinx.ext.napoleon',
93+
'sphinx.ext.imgmath',
9394
'recommonmark',
9495
'sphinx.ext.autosectionlabel',
9596
# 'm2r',

environment.yml

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ dependencies:
2626
- autopep8
2727
- check-manifest
2828
- twine==1.13.0
29+
- pillow<7.0.0
30+
- scipy>=0.13.3
31+
- scikit-learn>=0.20.0
32+
2933

3034
- pip:
3135
- test-tube>=0.7.5

pl_examples/domain_templates/computer_vision_fine_tuning.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from tempfile import TemporaryDirectory
2828
from typing import Optional, Generator, Union
2929

30+
from torch.nn import Module
31+
3032
import pytorch_lightning as pl
3133
import torch
3234
import torch.nn.functional as F
@@ -47,7 +49,7 @@
4749
# --- Utility functions ---
4850

4951

50-
def _make_trainable(module: torch.nn.Module) -> None:
52+
def _make_trainable(module: Module) -> None:
5153
"""Unfreezes a given module.
5254
5355
Args:
@@ -58,7 +60,7 @@ def _make_trainable(module: torch.nn.Module) -> None:
5860
module.train()
5961

6062

61-
def _recursive_freeze(module: torch.nn.Module,
63+
def _recursive_freeze(module: Module,
6264
train_bn: bool = True) -> None:
6365
"""Freezes the layers of a given module.
6466
@@ -80,7 +82,7 @@ def _recursive_freeze(module: torch.nn.Module,
8082
_recursive_freeze(module=child, train_bn=train_bn)
8183

8284

83-
def freeze(module: torch.nn.Module,
85+
def freeze(module: Module,
8486
n: Optional[int] = None,
8587
train_bn: bool = True) -> None:
8688
"""Freezes the layers up to index n (if n is not None).
@@ -101,7 +103,7 @@ def freeze(module: torch.nn.Module,
101103
_make_trainable(module=child)
102104

103105

104-
def filter_params(module: torch.nn.Module,
106+
def filter_params(module: Module,
105107
train_bn: bool = True) -> Generator:
106108
"""Yields the trainable parameters of a given module.
107109
@@ -124,7 +126,7 @@ def filter_params(module: torch.nn.Module,
124126
yield param
125127

126128

127-
def _unfreeze_and_add_param_group(module: torch.nn.Module,
129+
def _unfreeze_and_add_param_group(module: Module,
128130
optimizer: Optimizer,
129131
lr: Optional[float] = None,
130132
train_bn: bool = True):

pytorch_lightning/core/grads.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from typing import Dict, Union
55

66
import torch
7+
from torch.nn import Module
78

89

9-
class GradInformation(torch.nn.Module):
10+
class GradInformation(Module):
1011

1112
def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
1213
"""Compute each parameter's gradient's norm and their overall norm.

pytorch_lightning/core/hooks.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch import Tensor
5+
from torch.nn import Module
56
from torch.optim.optimizer import Optimizer
67
from pytorch_lightning.utilities import move_data_to_device
78

@@ -14,7 +15,7 @@
1415
APEX_AVAILABLE = True
1516

1617

17-
class ModelHooks(torch.nn.Module):
18+
class ModelHooks(Module):
1819

1920
# TODO: remove in v0.9.0
2021
def on_sanity_check_start(self):

pytorch_lightning/metrics/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@
2222
2323
2424
"""
25+
26+
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
27+
from pytorch_lightning.metrics.sklearn import (
28+
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
29+
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
30+
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric

pytorch_lightning/metrics/metric.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Optional, Union
2+
from typing import Any, Optional
33

44
import torch
55
import torch.distributed
6+
from torch.nn import Module
67

78
from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
89
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -11,7 +12,7 @@
1112
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']
1213

1314

14-
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
15+
class Metric(ABC, DeviceDtypeModuleMixin, Module):
1516
"""
1617
Abstract base class for metric implementation.
1718

0 commit comments

Comments
 (0)