
Commit 002b595

vfdev-5, n2cholas and sdesrozis authored
[BC-breaking] Make Metrics accumulate values on device specified by user (#1232) (#1238)
* Make Metrics accumulate values on device specified by user (#1232)
* update docs and docstrings for updated metrics (#1239)
* Updates to metrics_impl (#1266)
* Fix TPU tests for metrics_impl branch (#1277)
* metrics_impl: fix 2-GPU hvd tests and ensure consistent detaching (#1280)
* Fixes failing test on TPUs

Key changes:

* update Accuracy, Loss, MeanAbsoluteError, MeanPairwiseDistance, MeanSquaredError, TopKCategoricalAccuracy, Precision and Recall to accumulate values in tensors on the right device
* change all metrics' default device to CPU except RunningAverage; remove the Optional type from metric devices since the default is CPU
* support the device argument for RunningAverage and the accumulation metrics
* detach tensors earlier in update and detach output consistently for all metrics (including ConfusionMatrix); remove a redundant to() call
* add comments explaining the lack of detach in the accuracy metrics and in the metrics docs
* ensure metrics aren't created on XLA devices; move the XLA check to Metric.__init__ instead of individual metrics; update XLA tests; fix the epoch_metric XLA test
* add support for MetricsLambda with components on different devices
* replace the deleted callable check; remove the redundant Precision and Recall __init__, then restore it for docs rendering
* fix and improve device tests for metrics; fix TPU tests; fix Horovod two-GPU tests
* update docstrings and docs; black formatting; fixed isort; autopep8 fix; reverted run*.sh

Co-authored-by: vfdev <[email protected]>
Co-authored-by: Nicholas Vadivelu <[email protected]>
Co-authored-by: n2cholas <[email protected]>
Co-authored-by: Sylvain Desroziers <[email protected]>
Co-authored-by: AutoPEP8 <>
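In short: metric counters that used to be python numbers are now tensors kept on a user-chosen device, and ``device`` defaults to CPU instead of ``None``. A minimal sketch of the user-facing change (the tensors below are illustrative placeholders, not from the diff):

import torch
from ignite.metrics import Accuracy

# After this commit, internal counters live on the metric's device (CPU by
# default) and inputs are detached inside update(), so per-batch .item()
# host syncs are gone.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = Accuracy(device=device)

y_pred = torch.tensor([[0.2, 0.8], [0.7, 0.3]], device=device)
y = torch.tensor([1, 1], device=device)
metric.update((y_pred, y))
print(metric.compute())  # 0.5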
1 parent d92f1c6 commit 002b595

33 files changed: +1134 −458 lines

docs/source/metrics.rst

+10-6
@@ -120,21 +120,21 @@ specific condition (e.g. ignore user-defined classes):
 
     class CustomAccuracy(Metric):
 
-        def __init__(self, ignored_class, output_transform=lambda x: x):
+        def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
             self.ignored_class = ignored_class
             self._num_correct = None
             self._num_examples = None
-            super(CustomAccuracy, self).__init__(output_transform=output_transform)
+            super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)
 
         @reinit__is_reduced
         def reset(self):
-            self._num_correct = 0
+            self._num_correct = torch.tensor(0, device=self._device)
             self._num_examples = 0
             super(CustomAccuracy, self).reset()
 
         @reinit__is_reduced
         def update(self, output):
-            y_pred, y = output
+            y_pred, y = output[0].detach(), output[1].detach()
 
             indices = torch.argmax(y_pred, dim=1)
 
@@ -144,21 +144,25 @@ specific condition (e.g. ignore user-defined classes):
             indices = indices[mask]
             correct = torch.eq(indices, y).view(-1)
 
-            self._num_correct += torch.sum(correct).item()
+            self._num_correct += torch.sum(correct).to(self._device)
             self._num_examples += correct.shape[0]
 
         @sync_all_reduce("_num_examples", "_num_correct")
         def compute(self):
             if self._num_examples == 0:
                 raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
-            return self._num_correct / self._num_examples
+            return self._num_correct.item() / self._num_examples
 
 
 We imported necessary classes as :class:`~ignite.metrics.Metric`, :class:`~ignite.exceptions.NotComputableError` and
 decorators to adapt the metric for distributed setting. In ``reset`` method, we reset internal variables ``_num_correct``
 and ``_num_examples`` which are used to compute the custom metric. In ``updated`` method we define how to update
 the internal variables. And finally in ``compute`` method, we compute metric value.
 
+Notice that ``_num_correct`` is a tensor, since in ``update`` we accumulate tensor values. ``_num_examples`` is a python
+scalar since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before
+adding them to the internal variables.
+
 We can check this implementation in a simple case:
 
 .. code-block:: python
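A quick, hedged sanity check of the updated metric, assuming the full ``CustomAccuracy`` class from the docs page above is in scope (the random tensors are stand-ins for real model output, and ``ignored_class=3`` is an arbitrary choice):

import torch

metric = CustomAccuracy(ignored_class=3, device="cpu")

y_pred = torch.rand(16, 5)            # (batch_size, num_classes) scores
y = torch.randint(0, 5, size=(16,))   # integer class targets
metric.update((y_pred, y))
print(metric.compute())               # a python float in [0, 1]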

ignite/metrics/accumulation.py

+24-11
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 import torch
 
@@ -31,14 +31,19 @@ class VariableAccumulation(Metric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        device (str of torch.device, optional): optional device specification for internal storage.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
 
     """
 
     _required_output_keys = None
 
     def __init__(
-        self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None
+        self,
+        op: Callable,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         if not callable(op):
             raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
 
@@ -61,12 +66,13 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
     def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
         self._check_output_type(output)
 
-        if self._device is not None:
-            # Put output to the metric's device
-            if isinstance(output, torch.Tensor) and (output.device != self._device):
+        if isinstance(output, torch.Tensor):
+            output = output.detach()
+            if output.device != self._device:
                 output = output.to(self._device)
 
         self.accumulator = self._op(self.accumulator, output)
+
         if hasattr(output, "shape"):
             self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
         else:
 
@@ -111,11 +117,14 @@ class Average(VariableAccumulation):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        device (str of torch.device, optional): optional device specification for internal storage.
-
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
     """
 
-    def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         def _mean_op(a, x):
             if isinstance(x, torch.Tensor) and x.ndim > 1:
                 x = x.sum(dim=0)
 
@@ -155,11 +164,15 @@ class GeometricAverage(VariableAccumulation):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        device (str of torch.device, optional): optional device specification for internal storage.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
 
     """
 
-    def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor:
             if not isinstance(x, torch.Tensor):
                 x = torch.tensor(x)
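A minimal sketch of the new accumulation behaviour, assuming nothing beyond what the diff shows: ``update`` now detaches tensor inputs and moves them to the metric's device unconditionally.

import torch
from ignite.metrics import Average

avg = Average(device="cpu")  # device is now a plain (non-Optional) argument
for batch_mean in (torch.tensor(0.5), torch.tensor(0.25)):
    avg.update(batch_mean)   # detached and kept on the metric's device
print(avg.compute())         # tensor(0.3750)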

ignite/metrics/accuracy.py

+12-10
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union
 
 import torch
 
@@ -13,7 +13,7 @@ def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         self._is_multilabel = is_multilabel
         self._type = None
 
@@ -122,31 +122,33 @@ def thresholded_output_transform(output):
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
         is_multilabel (bool, optional): flag to use in multilabel case. By default, False.
-        device (str of torch.device, optional): unused argument.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
 
     """
 
     def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         self._num_correct = None
         self._num_examples = None
         super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)
 
     @reinit__is_reduced
     def reset(self) -> None:
-        self._num_correct = 0
+        self._num_correct = torch.tensor(0, device=self._device)
         self._num_examples = 0
         super(Accuracy, self).reset()
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
-        self._check_shape((y_pred, y))
-        self._check_type((y_pred, y))
+        self._check_shape(output)
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if self._type == "binary":
             correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
 
@@ -161,11 +163,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)
 
-        self._num_correct += torch.sum(correct).item()
+        self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]
 
     @sync_all_reduce("_num_examples", "_num_correct")
     def compute(self) -> torch.Tensor:
         if self._num_examples == 0:
             raise NotComputableError("Accuracy must have at least one example before it can be computed.")
-        return self._num_correct / self._num_examples
+        return self._num_correct.item() / self._num_examples
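To illustrate the BC-breaking part for subclasses and anyone poking at internals: code that read ``metric._num_correct`` as a python int will now see a tensor. A small sketch (the inputs are arbitrary):

import torch
from ignite.metrics import Accuracy

acc = Accuracy(device="cpu")
acc.update((torch.tensor([[0.9, 0.1], [0.4, 0.6], [0.3, 0.7]]),
            torch.tensor([0, 1, 0])))
print(type(acc._num_correct))  # torch.Tensor now; was int before this commit
print(acc.compute())           # 2/3, still a python float via .item()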

ignite/metrics/confusion_matrix.py

+6-4
@@ -30,7 +30,9 @@ class ConfusionMatrix(Metric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        device (str of torch.device, optional): optional device specification for internal storage.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
 
     Note:
         In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only
 
@@ -44,7 +46,7 @@ def __init__(
         num_classes: int,
         average: Optional[str] = None,
         output_transform: Callable = lambda x: x,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         if average is not None and average not in ("samples", "recall", "precision"):
             raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")
 
@@ -61,7 +63,7 @@ def reset(self) -> None:
         self._num_examples = 0
 
     def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if y_pred.ndimension() < 2:
             raise ValueError(
 
@@ -92,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         self._check_shape(output)
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         self._num_examples += y_pred.shape[0]
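A short illustrative use of ``ConfusionMatrix`` after the change; the logits below are arbitrary:

import torch
from ignite.metrics import ConfusionMatrix

cm = ConfusionMatrix(num_classes=2, device="cpu")
y_pred = torch.tensor([[2.0, 1.0], [0.5, 1.5]])  # (batch, num_classes) scores
y = torch.tensor([0, 0])                         # integer targets
cm.update((y_pred, y))   # inputs are detached inside update/_check_shape
print(cm.compute())      # tensor([[1, 1], [0, 0]]) - rows: target, cols: prediction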

ignite/metrics/fbeta.py

+4-2
@@ -15,7 +15,7 @@ def Fbeta(
     precision: Optional[Precision] = None,
     recall: Optional[Recall] = None,
     output_transform: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Union[str, torch.device] = torch.device("cpu"),
 ) -> MetricsLambda:
     """Calculates F-beta score
 
@@ -28,7 +28,9 @@ def Fbeta(
         output_transform (callable, optional): a callable that is used to transform the
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. It is used only if precision or recall are not provided.
-        device (str of torch.device, optional): optional device specification for internal storage.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
 
     Returns:
         MetricsLambda, F-beta metric
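A hedged sketch of driving ``Fbeta`` standalone with explicitly-constructed components, the pattern the device forwarding supports (``average=False`` is required for provided components; the toy batch is an assumption):

import torch
from ignite.metrics import Fbeta, Precision, Recall

precision = Precision(average=False, device="cpu")
recall = Recall(average=False, device="cpu")
f1 = Fbeta(beta=1.0, precision=precision, recall=recall)

y_pred = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
y = torch.tensor([0, 1])
precision.update((y_pred, y))  # update the components, then read the lambda
recall.update((y_pred, y))
print(f1.compute())            # ~1.0 for this perfectly-classified toy batch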

ignite/metrics/frequency.py

+5-1
@@ -1,3 +1,5 @@
+from typing import Callable, Optional, Union
+
 import torch
 
 import ignite.distributed as idist
 
@@ -35,7 +37,9 @@ class Frequency(Metric):
         # Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35]
     """
 
-    def __init__(self, output_transform=lambda x: x, device=None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         self._timer = None
         self._acc = None
         self._n = None
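Illustrative construction only, mirroring the docstring's words-per-second use case; the "ntokens" key is an assumption about your engine's output dict:

from ignite.metrics import Frequency

# Frequency now takes the same device argument as the other metrics.
wps = Frequency(output_transform=lambda x: x["ntokens"], device="cpu")
# typically attached with: wps.attach(trainer, name="wps")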

ignite/metrics/loss.py

+9-7
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union
 
 import torch
 
@@ -26,7 +26,9 @@ class Loss(Metric):
             keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
         batch_size (callable): a callable taking a target tensor that returns the
             first dimension size (usually the batch size).
-        device (str of torch.device, optional): unused argument.
+        device (str or torch.device): specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
 
     """
 
@@ -37,15 +39,15 @@ def __init__(
         loss_fn: Callable,
         output_transform: Callable = lambda x: x,
         batch_size: Callable = lambda x: len(x),
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(Loss, self).__init__(output_transform, device=device)
         self._loss_fn = loss_fn
         self._batch_size = batch_size
 
     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum = 0
+        self._sum = torch.tensor(0.0, device=self._device)
         self._num_examples = 0
 
     @reinit__is_reduced
 
@@ -55,17 +57,17 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
             kwargs = {}
         else:
             y_pred, y, kwargs = output
-        average_loss = self._loss_fn(y_pred, y, **kwargs)
+        average_loss = self._loss_fn(y_pred.detach(), y.detach(), **kwargs)
 
         if len(average_loss.shape) != 0:
             raise ValueError("loss_fn did not return the average loss.")
 
         n = self._batch_size(y)
-        self._sum += average_loss.item() * n
+        self._sum += average_loss.to(self._device) * n
         self._num_examples += n
 
     @sync_all_reduce("_sum", "_num_examples")
     def compute(self) -> None:
         if self._num_examples == 0:
             raise NotComputableError("Loss must have at least one example before it can be computed.")
-        return self._sum / self._num_examples
+        return self._sum.item() / self._num_examples
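A hedged sketch of the updated ``Loss``: the running ``_sum`` stays a tensor on the metric's device, and inputs are detached before calling ``loss_fn`` (the random batch here is illustrative):

import torch
import torch.nn.functional as F
from ignite.metrics import Loss

loss_metric = Loss(F.nll_loss, device="cpu")
y_pred = torch.log_softmax(torch.randn(8, 3), dim=1)  # log-probabilities
y = torch.randint(0, 3, size=(8,))
loss_metric.update((y_pred, y))
print(loss_metric.compute())  # average NLL as a python float (via .item())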

ignite/metrics/mean_absolute_error.py

+4-4
@@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric):
 
     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum_of_absolute_errors = 0.0
+        self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
         self._num_examples = 0
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
         absolute_errors = torch.abs(y_pred - y.view_as(y_pred))
-        self._sum_of_absolute_errors += torch.sum(absolute_errors).item()
+        self._sum_of_absolute_errors += torch.sum(absolute_errors).to(self._device)
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_absolute_errors", "_num_examples")
     def compute(self) -> Union[float, torch.Tensor]:
         if self._num_examples == 0:
             raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
-        return self._sum_of_absolute_errors / self._num_examples
+        return self._sum_of_absolute_errors.item() / self._num_examples
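A tiny worked example of the device-aware ``MeanAbsoluteError``, with values chosen so the result is easy to verify:

import torch
from ignite.metrics import MeanAbsoluteError

mae = MeanAbsoluteError(device="cpu")
mae.update((torch.tensor([2.0, 3.0]), torch.tensor([1.0, 5.0])))
print(mae.compute())  # (|2-1| + |3-5|) / 2 = 1.5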

ignite/metrics/mean_pairwise_distance.py

+6-6
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union
 
 import torch
 from torch.nn.functional import pairwise_distance
 
@@ -21,26 +21,26 @@ def __init__(
         p: int = 2,
         eps: float = 1e-6,
         output_transform: Callable = lambda x: x,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
         self._p = p
         self._eps = eps
 
     @reinit__is_reduced
     def reset(self):
-        self._sum_of_distances = 0.0
+        self._sum_of_distances = torch.tensor(0.0, device=self._device)
         self._num_examples = 0
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
         distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps)
-        self._sum_of_distances += torch.sum(distances).item()
+        self._sum_of_distances += torch.sum(distances).to(self._device)
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_distances", "_num_examples")
     def compute(self) -> Union[float, torch.Tensor]:
         if self._num_examples == 0:
             raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
-        return self._sum_of_distances / self._num_examples
+        return self._sum_of_distances.item() / self._num_examples
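And the same pattern for ``MeanPairwiseDistance``; a single 3-4-5 triangle keeps the arithmetic obvious:

import torch
from ignite.metrics import MeanPairwiseDistance

mpd = MeanPairwiseDistance(p=2, device="cpu")
mpd.update((torch.tensor([[3.0, 4.0]]), torch.tensor([[0.0, 0.0]])))
print(mpd.compute())  # ~5.0 (the eps term perturbs it negligibly)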
