Skip to content

Commit f1c732a

Browse files
SkafteNickiNicki SkaftewilliamFalconBordaedenlightning
authored
Metric docs fix (#2209)
* fix docs * Update docs/source/metrics.rst * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <[email protected]> * Update docs/source/metrics.rst * Update docs/source/metrics.rst * Update metrics.rst * title * fix * fix for num_classes * chlog * nb classes * hints * zero division * add tests * Update metrics.rst * Update classification.py * Update classification.py * prune doctests * docs * Apply suggestions from code review * Apply suggestions from code review * flake8 * doctests * formatting * cleaning * formatting * formatting * doctests * flake8 * docs * rename * rename * typo Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: William Falcon <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: edenlightning <[email protected]>
1 parent a5736d2 commit f1c732a

12 files changed

+658
-419
lines changed

CHANGELOG.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121

2222
### Added
2323

24-
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
25-
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
26-
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
24+
- Added metrics
25+
* Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
26+
* Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
27+
* Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
28+
* docs for all Metrics ([#2184](https://github.com/PyTorchLightning/pytorch-lightning/pull/2184), [#2209](https://github.com/PyTorchLightning/pytorch-lightning/pull/2209))
2729
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
2830
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
2931
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)

docs/source/metrics.rst

+105-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ Metrics are used to monitor model performance.
1212
In this package we provide two major pieces of functionality.
1313

1414
1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic.
15-
2. A collection of popular metrics already implemented for you.
15+
2. A collection of ready to use pupular metrics. There are two types of metrics: Class metrics and Functional metrics.
16+
3. A interface to call `sklearns metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
1617

1718
Example::
1819

@@ -28,12 +29,17 @@ Out::
2829

2930
tensor(0.7500)
3031

32+
.. warning::
33+
The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!
34+
to a few metrics. Please feel free to create an issue/PR if you have a proposed
35+
metric or have found a bug.
36+
3137
--------------
3238

3339
Implement a metric
3440
------------------
35-
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics
36-
will slow down training, use PyTorch metrics when possible.
41+
You can implement metrics as either a PyTorch metric or a Numpy metric (It is recommend to use PyTorch metrics when possible,
42+
since Numpy metrics slow down training).
3743

3844
Use :class:`TensorMetric` to implement native PyTorch metrics. This class
3945
handles automated DDP syncing and converts all inputs and outputs to tensors.
@@ -76,7 +82,7 @@ Here's an example showing how to implement a NumpyMetric
7682

7783
Class Metrics
7884
-------------
79-
The following are metrics which can be instantiated as part of a module definition (even with just
85+
Class metrics can be instantiated as part of a module definition (even with just
8086
plain PyTorch).
8187

8288
.. testcode::
@@ -316,3 +322,98 @@ to_onehot (F)
316322

317323
.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
318324
:noindex:
325+
326+
----------------
327+
328+
Sklearn interface
329+
-----------------
330+
331+
Lightning supports `sklearns metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
332+
as a backend for calculating metrics. Sklearns metrics are well tested and robust,
333+
but requires conversion between pytorch and numpy thus may slow down your computations.
334+
335+
To use the sklearn backend of metrics simply import as
336+
337+
.. code-block:: python
338+
339+
import pytorch_lightning.metrics.sklearns import plm
340+
metric = plm.Accuracy(normalize=True)
341+
val = metric(pred, target)
342+
343+
Each converted sklearn metric comes has the same interface as its
344+
originally counterpart (e.g. accuracy takes the additional `normalize` keyword).
345+
Like the native Lightning metrics these converted sklearn metrics also come
346+
with built-in distributed (ddp) support.
347+
348+
SklearnMetric (sk)
349+
^^^^^^^^^^^^^^^^^^
350+
351+
.. autofunction:: pytorch_lightning.metrics.sklearns.SklearnMetric
352+
:noindex:
353+
354+
Accuracy (sk)
355+
^^^^^^^^^^^^^
356+
357+
.. autofunction:: pytorch_lightning.metrics.sklearns.Accuracy
358+
:noindex:
359+
360+
AUC (sk)
361+
^^^^^^^^
362+
363+
.. autofunction:: pytorch_lightning.metrics.sklearns.AUC
364+
:noindex:
365+
366+
AveragePrecision (sk)
367+
^^^^^^^^^^^^^^^^^^^^^
368+
369+
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
370+
:noindex:
371+
372+
373+
ConfusionMatrix (sk)
374+
^^^^^^^^^^^^^^^^^^^^
375+
376+
.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
377+
:noindex:
378+
379+
F1 (sk)
380+
^^^^^^^
381+
382+
.. autofunction:: pytorch_lightning.metrics.sklearns.F1
383+
:noindex:
384+
385+
FBeta (sk)
386+
^^^^^^^^^^
387+
388+
.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
389+
:noindex:
390+
391+
Precision (sk)
392+
^^^^^^^^^^^^^^
393+
394+
.. autofunction:: pytorch_lightning.metrics.sklearns.Precision
395+
:noindex:
396+
397+
Recall (sk)
398+
^^^^^^^^^^^
399+
400+
.. autofunction:: pytorch_lightning.metrics.sklearns.Recall
401+
:noindex:
402+
403+
PrecisionRecallCurve (sk)
404+
^^^^^^^^^^^^^^^^^^^^^^^^^
405+
406+
.. autofunction:: pytorch_lightning.metrics.sklearns.PrecisionRecallCurve
407+
:noindex:
408+
409+
ROC (sk)
410+
^^^^^^^^
411+
412+
.. autofunction:: pytorch_lightning.metrics.sklearns.ROC
413+
:noindex:
414+
415+
AUROC (sk)
416+
^^^^^^^^^^
417+
418+
.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
419+
:noindex:

pytorch_lightning/metrics/__init__.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,41 @@
11
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
22
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
3-
from pytorch_lightning.metrics.sklearn import (
4-
SklearnMetric,
3+
from pytorch_lightning.metrics.classification import (
54
Accuracy,
65
AveragePrecision,
7-
AUC,
86
ConfusionMatrix,
97
F1,
108
FBeta,
11-
Precision,
129
Recall,
13-
PrecisionRecallCurve,
1410
ROC,
15-
AUROC)
11+
AUROC,
12+
DiceCoefficient,
13+
MulticlassPrecisionRecall,
14+
MulticlassROC,
15+
Precision,
16+
PrecisionRecall,
17+
)
18+
from pytorch_lightning.metrics.sklearns import (
19+
AUC,
20+
PrecisionRecallCurve,
21+
SklearnMetric,
22+
)
23+
24+
__all__ = [
25+
'AUC',
26+
'AUROC',
27+
'Accuracy',
28+
'AveragePrecision',
29+
'ConfusionMatrix',
30+
'DiceCoefficient',
31+
'F1',
32+
'FBeta',
33+
'MulticlassPrecisionRecall',
34+
'MulticlassROC',
35+
'Precision',
36+
'PrecisionRecall',
37+
'PrecisionRecallCurve',
38+
'ROC',
39+
'Recall',
40+
'SklearnMetric',
41+
]

0 commit comments

Comments
 (0)