Add Base class for metrics #1293
@@ -0,0 +1,6 @@
"""
Metrics
=======

TODO
"""
@@ -0,0 +1,122 @@
import numbers
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import partial
from typing import Union, Any, Optional

import torch
import torch.distributed
from torch.utils.data._utils.collate import np_str_obj_array_pattern

__all__ = ['BaseMetric']


class BaseMetric(torch.nn.Module, ABC):
    def __init__(self, name: str,
                 reduce_group: Optional[Any] = torch.distributed.group.WORLD,
                 reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM):
        """
        Abstract base class for metric implementations.

        Automatically handles the synchronization of computed metric values
        across DDP processes.

        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__()
        self.name = name
        self.reduce_op = reduce_op
        self.reduce_group = reduce_group

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Implements the actual metric computation.

        Returns:
            metric value
        """
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        return _sync_collections(super().__call__(*args, **kwargs),
                                 group=self.reduce_group,
                                 reduce_op=self.reduce_op)


def _sync_ddp(result: Union[torch.Tensor, numbers.Number],
              group: Any = torch.distributed.group.WORLD,
              reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> torch.Tensor:
    """
    Syncs and reduces a value across several DDP processes; with all-reduce,
    the reduced result is available on every process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        reduced value
    """
    # convert to tensor if necessary
    if not isinstance(result, torch.Tensor):
        result = torch.tensor(result)

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

    return result


def _sync_collections(result: Union[torch.Tensor, numbers.Number,
                                    Mapping, Sequence],
                      group: Any = torch.distributed.group.WORLD,
                      reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM
                      ) -> Union[torch.Tensor, numbers.Number,
                                 Mapping, Sequence]:
    """
    Recursively applies _sync_ddp to collections.

    Args:
        result: Tensor or Number or Mapping or Sequence holding the values to be reduced
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        the reduced collection
    """
    # function adapted from torch.utils.data._utils.collate
    elem_type = type(result)

    func = partial(_sync_collections, group=group, reduce_op=reduce_op)

    # convert numpy to tensor if possible
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array not of string classes and object
        if elem_type.__name__ != 'ndarray' \
                or np_str_obj_array_pattern.search(result.dtype.str) is None:
            result = torch.as_tensor(result)

    if isinstance(result, (torch.Tensor, numbers.Number)):
        return _sync_ddp(result, group=group, reduce_op=reduce_op)
    elif isinstance(result, Mapping):
        return elem_type({key: func(result[key]) for key in result})
    elif isinstance(result, tuple) and hasattr(result, '_fields'):  # namedtuple
        return elem_type(*(func(r) for r in result))
    elif isinstance(result, Sequence) and not isinstance(result, str):
        return elem_type([func(r) for r in result])
    # not possible to reduce this type
    else:
        return result
Review comment: this is not safe since all
Reply: this should be safe, since we just return the input variable as is, if there is nothing to sync...
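For readers skimming the diff, here is a minimal usage sketch (not part of this PR) of how a concrete metric might subclass `BaseMetric`. The `Accuracy` class and its inputs are invented for illustration; the import path matches the one used in the tests below.

# Hypothetical usage sketch: `Accuracy` and its inputs are made up for illustration.
import torch

from pytorch_lightning.metrics.metric import BaseMetric


class Accuracy(BaseMetric):
    def __init__(self):
        super().__init__(name='accuracy')

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # per-process accuracy; calling the metric instance additionally
        # syncs/reduces the result across DDP processes (sum by default)
        return (pred.argmax(dim=-1) == target).float().mean()


metric = Accuracy()
acc = metric(torch.randn(8, 10), torch.randint(0, 10, (8,)))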
@@ -0,0 +1,160 @@
from collections import namedtuple

import numpy as np
import pytest
import torch
import torch.distributed as dist

import tests.base.utils as tutils
from pytorch_lightning.metrics.metric import _sync_ddp, _sync_collections, BaseMetric


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_reduce_ddp():
    """Make sure sync-reduce works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    dist.init_process_group('gloo')

    tensor = torch.tensor([1.], device='cuda:0')

    reduced_tensor = _sync_ddp(tensor)
    assert reduced_tensor.item() == dist.get_world_size(), \
        'Sync-Reduce does not work properly with DDP and Tensors'

    number = 1.
    reduced_number = _sync_ddp(number)
    assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
    assert reduced_number.item() == dist.get_world_size(), \
        'Sync-Reduce does not work properly with DDP and Numbers'

    dist.destroy_process_group()


def test_sync_reduce_simple():
    """Make sure sync-reduce works without DDP"""
    tensor = torch.tensor([1.], device='cpu')

    reduced_tensor = _sync_ddp(tensor)

    assert torch.allclose(tensor, reduced_tensor), \
        'Sync-Reduce does not work properly without DDP and Tensors'

    number = 1.
    reduced_number = _sync_ddp(number)
    assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
    assert reduced_number.item() == number, 'Sync-Reduce does not work properly without DDP and Numbers'


def _sync_collections_test(is_ddp: bool):
    ntc = namedtuple('Foo', ['bar'])
    to_reduce = {
        'a': torch.tensor([1.]),  # Tensor
        'b': [torch.tensor([2.])],  # list
        'c': (torch.tensor([100.]),),  # tuple
        'd': ntc(bar=5.),  # named tuple
        'e': np.array([10.]),  # numpy array
        'f': 'this_is_a_dummy_str',  # string
        'g': 12.  # number
    }

    if is_ddp:
        factor = dist.get_world_size()
    else:
        factor = 1.

    expected_result = {
        'a': torch.tensor([1. * factor]),
        'b': [torch.tensor([2. * factor])],
        'c': (torch.tensor([100. * factor]),),
        'd': ntc(bar=torch.tensor([5. * factor])),
        'e': torch.tensor([10. * factor]),
        'f': 'this_is_a_dummy_str',
        'g': torch.tensor([12. * factor]),
    }

    reduced = _sync_collections(to_reduce)

    assert isinstance(reduced, dict), 'Type consistency of dict not preserved'
    assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved'
    assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \
        'At least one type was not correctly preserved'

    assert isinstance(reduced['a'], torch.Tensor), 'Reduction result of a Tensor should be a Tensor'
    assert torch.allclose(expected_result['a'], reduced['a']), \
        'Reduction of a tensor does not yield the expected value'

    assert isinstance(reduced['b'], list), 'Reduction result of a list should be a list'
    assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \
        'At least one value of list reduction did not come out as expected'

    assert isinstance(reduced['c'], tuple), 'Reduction result of a tuple should be a tuple'
    assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \
        'At least one value of tuple reduction did not come out as expected'

    assert isinstance(reduced['d'], ntc), 'Type consistency for named tuple not given'
    assert isinstance(reduced['d'].bar, torch.Tensor), \
        'Failure in type promotion while reducing fields of named tuples'
    assert torch.allclose(reduced['d'].bar, expected_result['d'].bar)

    assert isinstance(reduced['e'], torch.Tensor), 'Type promotion in reduction of numpy arrays failed'
    assert torch.allclose(reduced['e'], expected_result['e']), \
        'Reduction of numpy array did not yield the expected result'

    assert isinstance(reduced['f'], str), 'A string should not be reduced'
    assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'

    assert isinstance(reduced['g'], torch.Tensor), 'Reduction of a number should result in a tensor'
    assert torch.allclose(reduced['g'], expected_result['g']), \
        'Reduction of a number did not yield the desired result'


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Not enough GPUs to test sync reduce')
def test_sync_collections_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    dist.init_process_group('gloo')

    _sync_collections_test(True)

    dist.destroy_process_group()


def test_sync_collections_simple():
    _sync_collections_test(False)


def _test_base_metric(is_ddp):
    class DummyMetric(BaseMetric):
        def __init__(self):
            super().__init__(name='Dummy')

        def forward(self):
            return 1.

    dummy_metric = DummyMetric()

    assert dummy_metric.name == 'Dummy'
    metric_val = dummy_metric()

    if is_ddp:
        expected = dist.get_world_size()
    else:
        expected = 1.

    assert isinstance(metric_val, torch.Tensor), \
        'The result value should be synced and reduced which would promote the type from number to tensor'
    assert metric_val.item() == expected, 'Invalid value for reduction'


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Not enough GPUs to test with ddp')
def test_base_metric_ddp():
    _test_base_metric(True)


def test_base_metric_simple():
    _test_base_metric(False)
Review comment: isn't mean more frequently used?

Reply: mean is not available as a reduction op :)
https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp
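Since `ReduceOp` offers no mean, a common workaround is to all-reduce with SUM and divide by the world size. The snippet below is a sketch of that pattern, not part of this PR; `mean_all_reduce` is a hypothetical helper name.

# Sketch of a mean reduction via SUM + division; `mean_all_reduce` is hypothetical.
import torch
import torch.distributed as dist


def mean_all_reduce(value: torch.Tensor, group=dist.group.WORLD) -> torch.Tensor:
    if dist.is_available() and dist.is_initialized():
        # sum the value over all processes, then divide by the number of processes
        dist.all_reduce(value, op=dist.ReduceOp.SUM, group=group)
        value = value / dist.get_world_size(group=group)
    return value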