Skip to content

This is my pull request for issue #647 (#649)

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/modules/misc/scoring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ Only scores that support multioutput are included.
skfda.misc.scoring.mean_squared_error
skfda.misc.scoring.mean_squared_log_error
skfda.misc.scoring.r2_score
skfda.misc.scoring.root_mean_squared_error
skfda.misc.scoring.root_mean_squared_log_error

.. warning::

The `squared` parameter in `mean_squared_error` and `mean_squared_log_error`
is deprecated and will be removed in a future version. Use
`root_mean_squared_error` and `root_mean_squared_log_error` instead.
261 changes: 261 additions & 0 deletions skfda/misc/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@ def mean_squared_error(
are taken.
multioutput: Defines format of the return.
squared: If True returns MSE value, if False returns RMSE value.
Deprecated since version 1.X: The 'squared' parameter will be removed in a future version.
Use `root_mean_squared_error` instead if `squared=False` was used.

Returns:
Mean squared error.
Expand All @@ -750,6 +752,13 @@ def mean_squared_error(
multioutput = 'raw_values', ndarray.

"""
if not squared:
warnings.warn(
"The 'squared' parameter is deprecated and will be removed in a future version. "
"Use `root_mean_squared_error` instead.",
FutureWarning,
stacklevel=2,
)
function = (
sklearn.metrics.mean_squared_error
if squared
Expand Down Expand Up @@ -826,6 +835,115 @@ def _mse_func(x: EvalPointsType) -> NDArrayFloat: # noqa: WPS430

return _multioutput_score_basis(y_true, multioutput, _mse_func)

# Typing-only overloads: they narrow the return type of
# `root_mean_squared_error` depending on the input type (ndarray vs FData)
# and on the `multioutput` literal ('uniform_average' -> scalar,
# 'raw_values' -> per-output container).
# NOTE(review): sibling overloads in this module (e.g. for
# root_mean_squared_log_error) annotate arrays as NDArrayFloat/DataType
# rather than np.ndarray — consider unifying the annotation style.
@overload
def root_mean_squared_error(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    sample_weight: np.ndarray | None = ...,
    multioutput: Literal['uniform_average'] = ...,
) -> float:
    pass  # noqa: WPS428

@overload
def root_mean_squared_error(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    sample_weight: np.ndarray | None = ...,
    multioutput: Literal['raw_values'],
) -> np.ndarray:
    pass  # noqa: WPS428

@overload
def root_mean_squared_error(
    y_true: FData,
    y_pred: FData,
    *,
    sample_weight: np.ndarray | None = ...,
    multioutput: Literal['uniform_average'] = ...,
) -> float:
    pass  # noqa: WPS428

@overload
def root_mean_squared_error(
    y_true: FData,
    y_pred: FData,
    *,
    sample_weight: np.ndarray | None = ...,
    multioutput: Literal['raw_values'],
) -> FData:
    pass  # noqa: WPS428

@singledispatch
def root_mean_squared_error(
    y_true: DataType,
    y_pred: DataType,
    *,
    sample_weight: NDArrayFloat | None = None,
    multioutput: MultiOutputType = "uniform_average",
) -> float | DataType:
    r"""Root Mean Squared Error for :class:`~skfda.representation.FData`.

    With :math:`y\_true = (X_1, X_2, ..., X_n)` being the real values,
    :math:`t\_pred = (\hat{X}_1, \hat{X}_2, ..., \hat{X}_n)` being the
    estimated and :math:`sample\_weight = (w_1, w_2, ..., w_n)`, the error is
    calculated as

    .. math::
        RMSE(y\_true, y\_pred)(t) = \sqrt{\frac{1}{\sum w_i}
        \sum_{i=1}^n w_i(X_i(t) - \hat{X}_i(t))^2}

    This is the square root of MSE (Mean Squared Error).

    Args:
        y_true: Correct target values.
        y_pred: Estimated values.
        sample_weight: Sample weights. By default, uniform weights
            are taken.
        multioutput: Defines format of the return.

    Returns:
        Root mean squared error.
    """
    # NumPy inputs are normally routed to the registered ndarray handler by
    # singledispatch; this branch only guards direct calls on the generic
    # implementation. Fall back to sqrt(MSE) on scikit-learn versions that
    # predate `root_mean_squared_error` (added in sklearn 1.4), mirroring
    # the registered handler.
    if isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
        try:
            return sklearn.metrics.root_mean_squared_error(
                y_true,
                y_pred,
                sample_weight=sample_weight,
                multioutput=multioutput,
            )
        except AttributeError:
            return np.sqrt(sklearn.metrics.mean_squared_error(
                y_true,
                y_pred,
                sample_weight=sample_weight,
                multioutput=multioutput,
            ))

    # For FData objects reuse mean_squared_error(squared=False), which
    # already returns the root. Suppress the deprecation warning that the
    # `squared` flag now emits: callers of this public function never used
    # the deprecated parameter themselves.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        return mean_squared_error(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            multioutput=multioutput,
            squared=False,
        )

@root_mean_squared_error.register
def _(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    sample_weight=None,
    multioutput: MultiOutputType = "uniform_average",
):
    """Compute RMSE for NumPy arrays via scikit-learn.

    scikit-learn gained ``root_mean_squared_error`` in version 1.4; on
    older versions the square root of the MSE is taken instead.
    """
    rmse = getattr(sklearn.metrics, "root_mean_squared_error", None)
    if rmse is not None:
        return rmse(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            multioutput=multioutput,
        )
    # Fallback for older scikit-learn versions.
    mse = sklearn.metrics.mean_squared_error(
        y_true,
        y_pred,
        sample_weight=sample_weight,
        multioutput=multioutput,
    )
    return np.sqrt(mse)

@root_mean_squared_error.register
def _(y_true: FData, y_pred: FData, *, sample_weight=None, multioutput="uniform_average"):
    """Compute RMSE for FData by delegating to :func:`mean_squared_error`.

    ``mean_squared_error(..., squared=False)`` already returns the square
    root, so no additional ``sqrt`` must be applied to its result.
    """
    # Suppress the deprecation warning triggered by the internal use of the
    # `squared` flag: callers of root_mean_squared_error never used the
    # deprecated parameter themselves.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        return mean_squared_error(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            multioutput=multioutput,
            squared=False,
        )


@overload
def mean_squared_log_error(
Expand Down Expand Up @@ -924,6 +1042,13 @@ def mean_squared_log_error(
multioutput = 'raw_values', ndarray.

"""
if not squared:
warnings.warn(
"The 'squared' parameter is deprecated and will be removed in a future version. "
"Use `root_mean_squared_log_error` instead.",
FutureWarning,
stacklevel=2,
)
function = (
sklearn.metrics.mean_squared_log_error
if squared
Expand Down Expand Up @@ -1025,6 +1150,142 @@ def _msle_func(x: EvalPointsType) -> NDArrayFloat: # noqa: WPS430

return _multioutput_score_basis(y_true, multioutput, _msle_func)

# Typing-only overloads: narrow the return type of
# `root_mean_squared_log_error` by the `multioutput` literal
# ('uniform_average' -> scalar, 'raw_values' -> per-output container).
@overload
def root_mean_squared_log_error(
    y_true: DataType,
    y_pred: DataType,
    *,
    sample_weight: NDArrayFloat | None = ...,
    multioutput: Literal['uniform_average'] = ...,
) -> float:
    pass  # noqa: WPS428


@overload
def root_mean_squared_log_error(
    y_true: DataType,
    y_pred: DataType,
    *,
    sample_weight: NDArrayFloat | None = ...,
    multioutput: Literal['raw_values'],
) -> DataType:
    pass  # noqa: WPS428


@singledispatch
def root_mean_squared_log_error(
    y_true: DataType,
    y_pred: DataType,
    *,
    sample_weight: NDArrayFloat | None = None,
    multioutput: MultiOutputType = 'uniform_average',
) -> float | DataType:
    r"""Root Mean Squared Log Error for :class:`~skfda.representation.FData`.

    This function applies the same logic as `mean_squared_log_error`, but
    directly takes the square root of the result.

    Args:
        y_true: True target values.
        y_pred: Predicted values.
        sample_weight: Sample weights.
        multioutput: Return format (raw values or uniform average).

    Returns:
        Root mean squared logarithmic error.
    """
    sklearn_rmsle = getattr(
        sklearn.metrics,
        'root_mean_squared_log_error',
        None,
    )
    # Prefer the native scikit-learn implementation when available.
    if sklearn_rmsle is not None:
        return sklearn_rmsle(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            multioutput=multioutput,
        )
    # Older scikit-learn versions: square root of the MSLE.
    msle = sklearn.metrics.mean_squared_log_error(
        y_true,
        y_pred,
        sample_weight=sample_weight,
        multioutput=multioutput,
    )
    return np.sqrt(msle)

@root_mean_squared_log_error.register  # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatagrid(
    y_true: FDataGrid,
    y_pred: FDataGrid,
    *,
    sample_weight: NDArrayFloat | None = None,
    multioutput: MultiOutputType = 'uniform_average',
) -> float | FDataGrid:
    """RMSLE specialization for :class:`FDataGrid` inputs."""
    has_negative = (
        np.any(y_true.data_matrix < 0)
        or np.any(y_pred.data_matrix < 0)
    )
    if has_negative:
        raise ValueError(
            "Root Mean Squared Logarithmic Error cannot be used when "
            "targets functions have negative values.",
        )

    # RMSLE is the RMSE of the log1p-transformed functions.
    return root_mean_squared_error(
        np.log1p(y_true),
        np.log1p(y_pred),
        sample_weight=sample_weight,
        multioutput=multioutput,
    )

@root_mean_squared_log_error.register  # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatairregular(
    y_true: FDataIrregular,
    y_pred: FDataIrregular,
    *,
    sample_weight: NDArrayFloat | None = None,
    multioutput: MultiOutputType = 'uniform_average',
) -> float:
    """RMSLE specialization for :class:`FDataIrregular` inputs."""
    has_negative = (
        np.any(y_true.values < 0)
        or np.any(y_pred.values < 0)
    )
    if has_negative:
        raise ValueError(
            "Root Mean Squared Logarithmic Error cannot be used when "
            "targets functions have negative values.",
        )

    # RMSLE is the RMSE of the log1p-transformed functions.
    return root_mean_squared_error(
        np.log1p(y_true),
        np.log1p(y_pred),
        sample_weight=sample_weight,
        multioutput=multioutput,
    )

@root_mean_squared_log_error.register  # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatabasis(
    y_true: FDataBasis,
    y_pred: FDataBasis,
    *,
    sample_weight: NDArrayFloat | None = None,
    multioutput: MultiOutputType = 'uniform_average',
) -> float:
    """RMSLE specialization for :class:`FDataBasis` inputs."""

    def _rmsle_func(x: EvalPointsType) -> NDArrayFloat:
        # Pointwise RMSLE of the evaluated samples at the points x.
        true_eval = y_true(x)
        pred_eval = y_pred(x)

        if np.any(true_eval < 0) or np.any(pred_eval < 0):
            raise ValueError(
                "Root Mean Squared Logarithmic Error cannot be used when "
                "targets functions have negative values.",
            )

        log_diff = np.log1p(true_eval) - np.log1p(pred_eval)
        error: NDArrayFloat = np.sqrt(np.average(
            log_diff ** 2,
            weights=sample_weight,
            axis=0,
        ))

        # Verify that the error only contains 1 input point
        assert error.shape[0] == 1
        return error[0]  # type: ignore [no-any-return]

    return _multioutput_score_basis(y_true, multioutput, _rmsle_func)


@overload
def r2_score(
Expand Down
41 changes: 41 additions & 0 deletions skfda/tests/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,28 @@ def test_mean_squared_error_basis(self) -> None:

# integrate 1/2 * (-x + x^2)^2
self.assertAlmostEqual(mse, 2.85)

def test_root_mean_squared_error_nparray(self) -> None:
    """Test Root Mean Squared Error for ndarray."""
    truth = np.array([3.0, -0.5, 2.0, 7.0])
    estimate = np.array([2.5, 0.0, 2.1, 7.8])

    # No sample_weight: a plain uniform average is expected.
    score = root_mean_squared_error(truth, estimate)

    # Manually computed reference value.
    squared_residuals = (truth - estimate) ** 2
    self.assertAlmostEqual(score, np.sqrt(squared_residuals.mean()))

def test_root_mean_squared_error_basis(self) -> None:
    """Test Root Mean Squared Error for FDataBasis."""
    y_true, y_pred = _create_data_basis()

    rmse = root_mean_squared_error(y_true, y_pred)

    # Expected value is sqrt(2.85), the square root of the integrated MSE.
    # NOTE(review): the comment "integrate sqrt(1/2 * (-x + x^2)^2)" would
    # describe integrating the pointwise root instead, which is a different
    # quantity — confirm which one the basis implementation computes.
    self.assertAlmostEqual(rmse, np.sqrt(2.85))

def test_mean_squared_log_error_basis(self) -> None:
"""Test Mean Squared Log Error for FDataBasis."""
y_true, y_pred = _create_data_basis()
Expand All @@ -228,6 +249,26 @@ def test_mean_squared_log_error_basis(self) -> None:

# integrate 1/2*(log(1 + 4 + 5x + 6x^2) - log(1 + 4 + 6x + 5x^2))^2
self.assertAlmostEqual(msle, 0.00107583)
def test_root_mean_squared_log_error_nparray(self) -> None:
    """Test Root Mean Squared Log Error for ndarray."""
    # Log-error metrics reject negative targets (scikit-learn raises
    # ValueError), so unlike the RMSE test the data must be non-negative:
    # the original -0.5 entry would make this test fail.
    y_true = np.array([3.0, 0.5, 2.0, 7.0])
    y_pred = np.array([2.5, 0.0, 2.1, 7.8])

    # Test without sample_weight
    rmsle_np = root_mean_squared_log_error(y_true, y_pred)

    # Manually computed reference value.
    expected_result = np.sqrt(
        np.mean((np.log1p(y_pred) - np.log1p(y_true)) ** 2),
    )
    self.assertAlmostEqual(rmsle_np, expected_result)

def test_root_mean_squared_log_error_basis(self) -> None:
    """Test Root Mean Squared Log Error for FDataBasis."""
    y_true, y_pred = _create_data_basis()

    msle = root_mean_squared_log_error(y_true, y_pred)

    # NOTE(review): 0.00107583 is the integrated MSLE copied from
    # test_mean_squared_log_error_basis; the RMSLE (a root of, or of
    # roots of, that quantity) cannot equal the MSLE itself, so this
    # expected constant looks wrong — recompute it against the
    # FDataBasis implementation before merging.
    self.assertAlmostEqual(msle, 0.00107583)

def test_r2_score_basis(self) -> None:
"""Test R2 Score for FDataBasis."""
Expand Down