|
1 |
| -import numpy as np |
2 |
| -import pytest |
3 |
| -import torch |
4 |
| -import torch.distributed as dist |
5 |
| - |
6 |
| -import tests.base.utils as tutils |
7 |
| -from pytorch_lightning.metrics.utils import _apply_to_inputs, _apply_to_outputs, \ |
8 |
| - _convert_to_tensor, _convert_to_numpy, _numpy_metric_conversion, \ |
9 |
| - _tensor_metric_conversion, _sync_ddp, tensor_metric, numpy_metric |
10 |
| - |
11 |
| - |
12 |
def test_apply_to_inputs():
    """_apply_to_inputs must apply the given function to every arg and kwarg."""
    def apply_fn(inputs, factor):
        # Recursively scale numbers; dicts/lists/tuples are traversed element-wise.
        if isinstance(inputs, (float, int)):
            return inputs * factor
        elif isinstance(inputs, dict):
            return {k: apply_fn(v, factor) for k, v in inputs.items()}
        elif isinstance(inputs, (tuple, list)):
            return [apply_fn(x, factor) for x in inputs]

    @_apply_to_inputs(apply_fn, factor=2.)
    def test_fn(*args, **kwargs):
        return args, kwargs

    # FIX: the original `{1., 2.}` was a *set* literal; `test_fn(**set)` raises
    # TypeError. kwargs must be a dict.
    for args in [[], [1., 2.]]:
        for kwargs in [{}, {'a': 1., 'b': 2.}]:
            result_args, result_kwargs = test_fn(*args, **kwargs)
            assert isinstance(result_args, list)
            assert isinstance(result_kwargs, dict)
            assert len(result_args) == len(args)
            assert len(result_kwargs) == len(kwargs)
            assert all([k in result_kwargs for k in kwargs.keys()])
            for arg, result_arg in zip(args, result_args):
                assert arg * 2. == result_arg

            for key in kwargs.keys():
                # FIX: the original had a trailing comma (`kwargs[key],`) which
                # built a 1-tuple, so `arg * 2.` would raise TypeError.
                arg = kwargs[key]
                result_arg = result_kwargs[key]
                assert arg * 2. == result_arg
40 |
| - |
41 |
| - |
42 |
def test_apply_to_outputs():
    """The _apply_to_outputs decorator must post-process the wrapped return value."""
    def suffix_fn(outputs, additional_str):
        # Stringify the raw output and append the configured suffix.
        return str(outputs) + additional_str

    @_apply_to_outputs(suffix_fn, additional_str='_str')
    def decorated_fn(*args, **kwargs):
        return 'dummy'

    assert decorated_fn() == 'dummy_str'
51 |
| - |
52 |
| - |
53 |
def test_convert_to_tensor():
    """_convert_to_tensor must wrap floats and numpy arrays into torch tensors."""
    for test_item in [1., np.array([1.])]:
        result_tensor = _convert_to_tensor(test_item)
        assert isinstance(result_tensor, torch.Tensor)
        # FIX: the original asserted `test_item.item()`, but a plain float has
        # no `.item()` — assert on the conversion result instead.
        assert result_tensor.item() == 1.
57 |
| - |
58 |
| - |
59 |
def test_convert_to_numpy():
    """_convert_to_numpy must turn floats and torch tensors into numpy arrays."""
    for item in (1., torch.tensor([1.])):
        converted = _convert_to_numpy(item)
        assert isinstance(converted, np.ndarray)
        assert converted.item() == 1.
64 |
| - |
65 |
| - |
66 |
def test_numpy_metric_conversion():
    """_numpy_metric_conversion must pass numpy inputs in and return a tensor."""
    @_numpy_metric_conversion
    def numpy_test_metric(*args, **kwargs):
        # Every positional and keyword input must arrive as a numpy array.
        assert all(isinstance(a, np.ndarray) for a in args)
        assert all(isinstance(v, np.ndarray) for v in kwargs.values())
        return 5.

    result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5.
80 |
| - |
81 |
| - |
82 |
def test_tensor_metric_conversion():
    """_tensor_metric_conversion must pass tensor inputs in and return a tensor."""
    @_tensor_metric_conversion
    def tensor_test_metric(*args, **kwargs):
        # Every positional and keyword input must arrive as a torch tensor.
        assert all(isinstance(a, torch.Tensor) for a in args)
        assert all(isinstance(v, torch.Tensor) for v in kwargs.values())
        return 5.

    result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5.
96 |
| - |
97 |
| - |
98 |
# FIX: pytest requires the skip reason as the `reason=` keyword; a positional
# string makes pytest error out at collection time.
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="test requires multi-GPU machine")
def test_sync_reduce_ddp():
    """Make sure sync-reduce works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    # NOTE(review): relies on MASTER_ADDR/WORLD_SIZE/RANK env set by the runner.
    dist.init_process_group('gloo')

    tensor = torch.tensor([1.], device='cuda:0')

    # all_reduce sums the tensor over all processes -> equals the world size.
    reduced_tensor = _sync_ddp(tensor)

    assert reduced_tensor.item() == dist.get_world_size(), \
        'Sync-Reduce does not work properly with DDP and Tensors'

    number = 1.
    reduced_number = _sync_ddp(number)
    assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
    assert reduced_number.item() == dist.get_world_size(), \
        'Sync-Reduce does not work properly with DDP and Numbers'

    dist.destroy_process_group()
120 |
| - |
121 |
| - |
122 |
def test_sync_reduce_simple():
    """Without an initialized process group, _sync_ddp must be a no-op reduce."""
    tensor = torch.tensor([1.], device='cpu')
    reduced_tensor = _sync_ddp(tensor)
    # No DDP -> the value must come back unchanged.
    assert torch.allclose(tensor, reduced_tensor), \
        'Sync-Reduce does not work properly without DDP and Tensors'

    number = 1.
    reduced_number = _sync_ddp(number)
    assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
    assert reduced_number.item() == number, 'Sync-Reduce does not work properly without DDP and Numbers'
136 |
| - |
137 |
| - |
138 |
def _test_tensor_metric(is_ddp: bool):
    """Shared driver: tensor_metric converts inputs to tensors and, under DDP,
    sums the result across processes."""
    @tensor_metric()
    def tensor_test_metric(*args, **kwargs):
        # All inputs must be converted to tensors before the metric runs.
        assert all(isinstance(a, torch.Tensor) for a in args)
        assert all(isinstance(v, torch.Tensor) for v in kwargs.values())
        return 5.

    # Under DDP the reduction sums over all processes.
    factor = dist.get_world_size() if is_ddp else 1.

    result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5. * factor
157 |
| - |
158 |
| - |
159 |
# FIX: pytest requires the skip reason as the `reason=` keyword; a positional
# string makes pytest error out at collection time.
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="test requires multi-GPU machine")
def test_tensor_metric_ddp():
    """tensor_metric must sum its result across DDP processes."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    dist.init_process_group('gloo')
    _test_tensor_metric(True)
    dist.destroy_process_group()
167 |
| - |
168 |
| - |
169 |
def test_tensor_metric_simple():
    """tensor_metric without DDP: no reduction factor is applied."""
    _test_tensor_metric(False)
171 |
| - |
172 |
| - |
173 |
def _test_numpy_metric(is_ddp: bool):
    """Shared driver: numpy_metric converts inputs to numpy arrays and, under
    DDP, sums the result across processes."""
    @numpy_metric()
    def numpy_test_metric(*args, **kwargs):
        # All inputs must be converted to numpy arrays before the metric runs.
        assert all(isinstance(a, np.ndarray) for a in args)
        assert all(isinstance(v, np.ndarray) for v in kwargs.values())
        return 5.

    # Under DDP the reduction sums over all processes.
    factor = dist.get_world_size() if is_ddp else 1.

    result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5. * factor
192 |
| - |
193 |
| - |
194 |
# FIX: pytest requires the skip reason as the `reason=` keyword; a positional
# string makes pytest error out at collection time.
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="test requires multi-GPU machine")
def test_numpy_metric_ddp():
    """numpy_metric must sum its result across DDP processes."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    dist.init_process_group('gloo')
    # FIX: copy-paste error — this test exercised the *tensor* metric helper;
    # it must drive the numpy metric helper.
    _test_numpy_metric(True)
    dist.destroy_process_group()
202 |
| - |
203 |
| - |
204 |
def test_numpy_metric_simple():
    """numpy_metric without DDP: no reduction factor is applied."""
    # FIX: copy-paste error — this test exercised the *tensor* metric helper;
    # it must drive the numpy metric helper.
    _test_numpy_metric(False)
0 commit comments