Skip to content

Commit f53c676

Browse files
justusschockBorda
andcommitted
New metric classes (#1326)
* Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <[email protected]> * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <[email protected]> * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <[email protected]> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <[email protected]> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <[email protected]> * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec <[email protected]> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <[email protected]> * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec <[email protected]> * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec <[email protected]>
1 parent 1f2ed9d commit f53c676

File tree

2 files changed

+3
-205
lines changed

2 files changed

+3
-205
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
## Metrics (will be added to unreleased once the metric branch was finished)
8+
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326))
9+
710
## [unreleased] - YYYY-MM-DD
811

912
### Added

tests/metrics/__init__.py

-205
Original file line numberDiff line numberDiff line change
@@ -1,205 +0,0 @@
1-
import numpy as np
2-
import pytest
3-
import torch
4-
import torch.distributed as dist
5-
6-
import tests.base.utils as tutils
7-
from pytorch_lightning.metrics.utils import _apply_to_inputs, _apply_to_outputs, \
8-
_convert_to_tensor, _convert_to_numpy, _numpy_metric_conversion, \
9-
_tensor_metric_conversion, _sync_ddp, tensor_metric, numpy_metric
10-
11-
12-
def test_apply_to_inputs():
13-
def apply_fn(inputs, factor):
14-
if isinstance(inputs, (float, int)):
15-
return inputs * factor
16-
elif isinstance(inputs, dict):
17-
return {k: apply_fn(v, factor) for k, v in inputs.items()}
18-
elif isinstance(inputs, (tuple, list)):
19-
return [apply_fn(x, factor) for x in inputs]
20-
21-
@_apply_to_inputs(apply_fn, factor=2.)
22-
def test_fn(*args, **kwargs):
23-
return args, kwargs
24-
25-
for args in [[], [1., 2.]]:
26-
for kwargs in [{}, {1., 2.}]:
27-
result_args, result_kwargs = test_fn(*args, **kwargs)
28-
assert isinstance(result_args, list)
29-
assert isinstance(result_kwargs, dict)
30-
assert len(result_args) == len(args)
31-
assert len(result_kwargs) == len(kwargs)
32-
assert all([k in result_kwargs for k in kwargs.keys()])
33-
for arg, result_arg in zip(args, result_args):
34-
assert arg * 2. == result_arg
35-
36-
for key in kwargs.keys():
37-
arg = kwargs[key],
38-
result_arg = result_kwargs[key]
39-
assert arg * 2. == result_arg
40-
41-
42-
def test_apply_to_outputs():
43-
def apply_fn(inputs, additional_str):
44-
return str(inputs) + additional_str
45-
46-
@_apply_to_outputs(apply_fn, additional_str='_str')
47-
def test_fn(*args, **kwargs):
48-
return 'dummy'
49-
50-
assert test_fn() == 'dummy_str'
51-
52-
53-
def test_convert_to_tensor():
54-
for test_item in [1., np.array([1.])]:
55-
assert isinstance(_convert_to_tensor(test_item), torch.Tensor)
56-
assert test_item.item() == 1.
57-
58-
59-
def test_convert_to_numpy():
60-
for test_item in [1., torch.tensor([1.])]:
61-
result = _convert_to_numpy(test_item)
62-
assert isinstance(result, np.ndarray)
63-
assert result.item() == 1.
64-
65-
66-
def test_numpy_metric_conversion():
67-
@_numpy_metric_conversion
68-
def numpy_test_metric(*args, **kwargs):
69-
for arg in args:
70-
assert isinstance(arg, np.ndarray)
71-
72-
for v in kwargs.values():
73-
assert isinstance(v, np.ndarray)
74-
75-
return 5.
76-
77-
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
78-
assert isinstance(result, torch.Tensor)
79-
assert result.item() == 5.
80-
81-
82-
def test_tensor_metric_conversion():
83-
@_tensor_metric_conversion
84-
def tensor_test_metric(*args, **kwargs):
85-
for arg in args:
86-
assert isinstance(arg, torch.Tensor)
87-
88-
for v in kwargs.values():
89-
assert isinstance(v, torch.Tensor)
90-
91-
return 5.
92-
93-
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
94-
assert isinstance(result, torch.Tensor)
95-
assert result.item() == 5.
96-
97-
98-
@pytest.mark.skipif(torch.cuda.device_count() < 2, "test requires multi-GPU machine")
99-
def test_sync_reduce_ddp():
100-
"""Make sure sync-reduce works with DDP"""
101-
tutils.reset_seed()
102-
tutils.set_random_master_port()
103-
104-
dist.init_process_group('gloo')
105-
106-
tensor = torch.tensor([1.], device='cuda:0')
107-
108-
reduced_tensor = _sync_ddp(tensor)
109-
110-
assert reduced_tensor.item() == dist.get_world_size(), \
111-
'Sync-Reduce does not work properly with DDP and Tensors'
112-
113-
number = 1.
114-
reduced_number = _sync_ddp(number)
115-
assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
116-
assert reduced_number.item() == dist.get_world_size(), \
117-
'Sync-Reduce does not work properly with DDP and Numbers'
118-
119-
dist.destroy_process_group()
120-
121-
122-
def test_sync_reduce_simple():
123-
"""Make sure sync-reduce works without DDP"""
124-
tensor = torch.tensor([1.], device='cpu')
125-
126-
reduced_tensor = _sync_ddp(tensor)
127-
128-
assert torch.allclose(tensor,
129-
reduced_tensor), 'Sync-Reduce does not work properly without DDP and Tensors'
130-
131-
number = 1.
132-
133-
reduced_number = _sync_ddp(number)
134-
assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
135-
assert reduced_number.item() == number, 'Sync-Reduce does not work properly without DDP and Numbers'
136-
137-
138-
def _test_tensor_metric(is_ddp: bool):
139-
@tensor_metric()
140-
def tensor_test_metric(*args, **kwargs):
141-
for arg in args:
142-
assert isinstance(arg, torch.Tensor)
143-
144-
for v in kwargs.values():
145-
assert isinstance(v, torch.Tensor)
146-
147-
return 5.
148-
149-
if is_ddp:
150-
factor = dist.get_world_size()
151-
else:
152-
factor = 1.
153-
154-
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
155-
assert isinstance(result, torch.Tensor)
156-
assert result.item() == 5. * factor
157-
158-
159-
@pytest.mark.skipif(torch.cuda.device_count() < 2, "test requires multi-GPU machine")
160-
def test_tensor_metric_ddp():
161-
tutils.reset_seed()
162-
tutils.set_random_master_port()
163-
164-
dist.init_process_group('gloo')
165-
_test_tensor_metric(True)
166-
dist.destroy_process_group()
167-
168-
169-
def test_tensor_metric_simple():
170-
_test_tensor_metric(False)
171-
172-
173-
def _test_numpy_metric(is_ddp: bool):
174-
@numpy_metric()
175-
def numpy_test_metric(*args, **kwargs):
176-
for arg in args:
177-
assert isinstance(arg, np.ndarray)
178-
179-
for v in kwargs.values():
180-
assert isinstance(v, np.ndarray)
181-
182-
return 5.
183-
184-
if is_ddp:
185-
factor = dist.get_world_size()
186-
else:
187-
factor = 1.
188-
189-
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
190-
assert isinstance(result, torch.Tensor)
191-
assert result.item() == 5. * factor
192-
193-
194-
@pytest.mark.skipif(torch.cuda.device_count() < 2, "test requires multi-GPU machine")
195-
def test_numpy_metric_ddp():
196-
tutils.reset_seed()
197-
tutils.set_random_master_port()
198-
199-
dist.init_process_group('gloo')
200-
_test_tensor_metric(True)
201-
dist.destroy_process_group()
202-
203-
204-
def test_numpy_metric_simple():
205-
_test_tensor_metric(False)

0 commit comments

Comments
 (0)