4
4
from pytorch_lightning .metrics .functional .reduction import reduce
5
5
6
6
7
+ def mse (
8
+ pred : torch .Tensor ,
9
+ target : torch .Tensor ,
10
+ reduction : str = 'elementwise_mean'
11
+ ) -> torch .Tensor :
12
+ """
13
+ Computes mean squared error
14
+
15
+ Args:
16
+ pred: estimated labels
17
+ target: ground truth labels
18
+ reduction: method for reducing mse (default: takes the mean)
19
+ Available reduction methods:
20
+
21
+ - elementwise_mean: takes the mean
22
+ - none: pass array
23
+ - sum: add elements
24
+
25
+ Return:
26
+ Tensor with MSE
27
+
28
+ Example:
29
+
30
+ >>> x = torch.tensor([0., 1, 2, 3])
31
+ >>> y = torch.tensor([0., 1, 2, 2])
32
+ >>> mse(x, y)
33
+ tensor(0.2500)
34
+
35
+ """
36
+ mse = F .mse_loss (pred , target , reduction = 'none' )
37
+ mse = reduce (mse , reduction = reduction )
38
+ return mse
39
+
40
+
41
+ def rmse (
42
+ pred : torch .Tensor ,
43
+ target : torch .Tensor ,
44
+ reduction : str = 'elementwise_mean'
45
+ ) -> torch .Tensor :
46
+ """
47
+ Computes root mean squared error
48
+
49
+ Args:
50
+ pred: estimated labels
51
+ target: ground truth labels
52
+ reduction: method for reducing rmse (default: takes the mean)
53
+ Available reduction methods:
54
+
55
+ - elementwise_mean: takes the mean
56
+ - none: pass array
57
+ - sum: add elements
58
+
59
+ Return:
60
+ Tensor with RMSE
61
+
62
+
63
+ >>> x = torch.tensor([0., 1, 2, 3])
64
+ >>> y = torch.tensor([0., 1, 2, 2])
65
+ >>> rmse(x, y)
66
+ tensor(0.5000)
67
+
68
+ """
69
+ rmse = torch .sqrt (mse (pred , target , reduction = reduction ))
70
+ return rmse
71
+
72
+
73
+ def mae (
74
+ pred : torch .Tensor ,
75
+ target : torch .Tensor ,
76
+ reduction : str = 'elementwise_mean'
77
+ ) -> torch .Tensor :
78
+ """
79
+ Computes mean absolute error
80
+
81
+ Args:
82
+ pred: estimated labels
83
+ target: ground truth labels
84
+ reduction: method for reducing mae (default: takes the mean)
85
+ Available reduction methods:
86
+
87
+ - elementwise_mean: takes the mean
88
+ - none: pass array
89
+ - sum: add elements
90
+
91
+ Return:
92
+ Tensor with MAE
93
+
94
+ Example:
95
+
96
+ >>> x = torch.tensor([0., 1, 2, 3])
97
+ >>> y = torch.tensor([0., 1, 2, 2])
98
+ >>> mae(x, y)
99
+ tensor(0.2500)
100
+
101
+ """
102
+ mae = F .l1_loss (pred , target , reduction = 'none' )
103
+ mae = reduce (mae , reduction = reduction )
104
+ return mae
105
+
106
+
107
+ def rmsle (
108
+ pred : torch .Tensor ,
109
+ target : torch .Tensor ,
110
+ reduction : str = 'elementwise_mean'
111
+ ) -> torch .Tensor :
112
+ """
113
+ Computes root mean squared log error
114
+
115
+ Args:
116
+ pred: estimated labels
117
+ target: ground truth labels
118
+ reduction: method for reducing rmsle (default: takes the mean)
119
+ Available reduction methods:
120
+
121
+ - elementwise_mean: takes the mean
122
+ - none: pass array
123
+ - sum: add elements
124
+
125
+ Return:
126
+ Tensor with RMSLE
127
+
128
+ Example:
129
+
130
+ >>> x = torch.tensor([0., 1, 2, 3])
131
+ >>> y = torch.tensor([0., 1, 2, 2])
132
+ >>> rmsle(x, y)
133
+ tensor(0.0207)
134
+
135
+ """
136
+ rmsle = mse (torch .log (pred + 1 ), torch .log (target + 1 ), reduction = reduction )
137
+ return rmsle
138
+
139
+
7
140
def psnr (
8
141
pred : torch .Tensor ,
9
142
target : torch .Tensor ,
@@ -12,14 +145,22 @@ def psnr(
12
145
reduction : str = 'elementwise_mean'
13
146
) -> torch .Tensor :
14
147
"""
15
- Computes the peak signal-to-noise ratio metric
148
+ Computes the peak signal-to-noise ratio
16
149
17
150
Args:
18
151
pred: estimated signal
19
152
target: groun truth signal
20
- data_range: the range of the data. If None, it is determined from the data (max - min).
153
+ data_range: the range of the data. If None, it is determined from the data (max - min)
21
154
base: a base of a logarithm to use (default: 10)
22
155
reduction: method for reducing psnr (default: takes the mean)
156
+ Available reduction methods:
157
+
158
+ - elementwise_mean: takes the mean
159
+ - none: pass array
160
+ - sum add elements
161
+
162
+ Return:
163
+ Tensor with PSNR score
23
164
24
165
Example:
25
166
@@ -29,12 +170,15 @@ def psnr(
29
170
>>> metric = PSNR()
30
171
>>> metric(pred, target)
31
172
tensor(2.5527)
173
+
32
174
"""
33
175
34
176
if data_range is None :
35
177
data_range = max (target .max () - target .min (), pred .max () - pred .min ())
36
178
else :
37
179
data_range = torch .tensor (float (data_range ))
38
- mse = F .mse_loss (pred .view (- 1 ), target .view (- 1 ), reduction = reduction )
39
- psnr_base_e = 2 * torch .log (data_range ) - torch .log (mse )
40
- return psnr_base_e * (10 / torch .log (torch .tensor (base )))
180
+
181
+ mse_score = mse (pred .view (- 1 ), target .view (- 1 ), reduction = reduction )
182
+ psnr_base_e = 2 * torch .log (data_range ) - torch .log (mse_score )
183
+ psnr = psnr_base_e * (10 / torch .log (torch .tensor (base )))
184
+ return psnr
0 commit comments