Skip to content

Commit a70513b

Browse files
justusschockBorda
authored andcommitted
add explicit check for dtype to convert to
1 parent b0739e1 commit a70513b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytorch_lightning/metrics/converters.py

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def _convert_to_tensor(data: Any) -> Any:
7777
# is not array of object
7878
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
7979
return torch.from_numpy(data)
80+
elif isinstance(data, torch.Tensor):
81+
return data
8082

8183
raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__)
8284

@@ -94,6 +96,8 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) ->
9496
return data.cpu().detach().numpy()
9597
elif isinstance(data, numbers.Number):
9698
return np.array([data])
99+
elif isinstance(data, np.ndarray):
100+
return data
97101

98102
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
99103

0 commit comments

Comments
 (0)