Skip to content

This is my pull request for the issue 647 #649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/modules/misc/scoring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ Only scores that support multioutput are included.
skfda.misc.scoring.mean_squared_error
skfda.misc.scoring.mean_squared_log_error
skfda.misc.scoring.r2_score
skfda.misc.scoring.root_mean_squared_error
skfda.misc.scoring.root_mean_squared_log_error

.. warning::

The `squared` parameter in `mean_squared_error` and `mean_squared_log_error`
is deprecated and will be removed in a future version. Use
`root_mean_squared_error` and `root_mean_squared_log_error` instead.
151 changes: 151 additions & 0 deletions docs/sg_execution_times.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@

:orphan:

.. _sphx_glr_sg_execution_times:


Computation times
=================
**00:07.828** total execution time for 39 files **from all galleries**:

.. container::

.. raw:: html

<style scoped>
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/5.3.0/css/bootstrap.min.css" rel="stylesheet" />
<link href="https://cdn.datatables.net/1.13.6/css/dataTables.bootstrap5.min.css" rel="stylesheet" />
</style>
<script src="https://code.jquery.com/jquery-3.7.0.js"></script>
<script src="https://cdn.datatables.net/1.13.6/js/jquery.dataTables.min.js"></script>
<script src="https://cdn.datatables.net/1.13.6/js/dataTables.bootstrap5.min.js"></script>
<script type="text/javascript" class="init">
$(document).ready( function () {
$('table.sg-datatable').DataTable({order: [[1, 'desc']]});
} );
</script>

.. list-table::
:header-rows: 1
:class: table table-striped sg-datatable

* - Example
- Time
- Mem (MB)
* - :ref:`sphx_glr_auto_examples_plot_aemet_unsupervised.py` (``..\examples\plot_aemet_unsupervised.py``)
- 00:07.828
- 0.0
* - :ref:`sphx_glr_auto_examples_expand_skfda_plot_basis_subclass.py` (``..\examples\expand_skfda\plot_basis_subclass.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_expand_skfda_plot_new_evaluator.py` (``..\examples\expand_skfda\plot_new_evaluator.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_boxplot.py` (``..\examples\plot_boxplot.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_classification_methods.py` (``..\examples\plot_classification_methods.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_clustering.py` (``..\examples\plot_clustering.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_composition.py` (``..\examples\plot_composition.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_depth_classification.py` (``..\examples\plot_depth_classification.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_discrete_representation.py` (``..\examples\plot_discrete_representation.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_elastic_registration.py` (``..\examples\plot_elastic_registration.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_explore.py` (``..\examples\plot_explore.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_extrapolation.py` (``..\examples\plot_extrapolation.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_fpca.py` (``..\examples\plot_fpca.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_fpca_inverse_transform_outl_detection.py` (``..\examples\plot_fpca_inverse_transform_outl_detection.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_fpca_regression.py` (``..\examples\plot_fpca_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_functional_regression.py` (``..\examples\plot_functional_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_interpolation.py` (``..\examples\plot_interpolation.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_k_neighbors_classification.py` (``..\examples\plot_k_neighbors_classification.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_kernel_regression.py` (``..\examples\plot_kernel_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_kernel_smoothing.py` (``..\examples\plot_kernel_smoothing.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_landmark_registration.py` (``..\examples\plot_landmark_registration.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_landmark_shift.py` (``..\examples\plot_landmark_shift.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_magnitude_shape.py` (``..\examples\plot_magnitude_shape.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_magnitude_shape_synthetic.py` (``..\examples\plot_magnitude_shape_synthetic.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_neighbors_functional_regression.py` (``..\examples\plot_neighbors_functional_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_neighbors_scalar_regression.py` (``..\examples\plot_neighbors_scalar_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_oneway.py` (``..\examples\plot_oneway.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_oneway_synthetic.py` (``..\examples\plot_oneway_synthetic.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_pairwise_alignment.py` (``..\examples\plot_pairwise_alignment.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_phonemes_classification.py` (``..\examples\plot_phonemes_classification.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_radius_neighbors_classification.py` (``..\examples\plot_radius_neighbors_classification.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_representation.py` (``..\examples\plot_representation.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_shift_registration.py` (``..\examples\plot_shift_registration.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_surface_boxplot.py` (``..\examples\plot_surface_boxplot.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_examples_plot_tecator_regression.py` (``..\examples\plot_tecator_regression.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_tutorial_plot_basis_representation.py` (``..\tutorial\plot_basis_representation.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_tutorial_plot_getting_data.py` (``..\tutorial\plot_getting_data.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_tutorial_plot_introduction.py` (``..\tutorial\plot_introduction.py``)
- 00:00.000
- 0.0
* - :ref:`sphx_glr_auto_tutorial_plot_skfda_sklearn.py` (``..\tutorial\plot_skfda_sklearn.py``)
- 00:00.000
- 0.0
160 changes: 160 additions & 0 deletions skfda/misc/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@
are taken.
multioutput: Defines format of the return.
squared: If True returns MSE value, if False returns RMSE value.
Deprecated since version 1.X: The 'squared' parameter will be removed in a future version.
Use `root_mean_squared_error` instead if `squared=False` was used.

Returns:
Mean squared error.
Expand All @@ -750,6 +752,13 @@
multioutput = 'raw_values', ndarray.

"""
if not squared:
warnings.warn(
"The 'squared' parameter is deprecated and will be removed in a future version. "
"Use `root_mean_squared_error` instead.",
FutureWarning,
stacklevel=2,
)
function = (
sklearn.metrics.mean_squared_error
if squared
Expand Down Expand Up @@ -826,6 +835,24 @@

return _multioutput_score_basis(y_true, multioutput, _mse_func)

@singledispatch
def root_mean_squared_error(
y_true: DataType,
y_pred: DataType,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: MultiOutputType = "uniform_average",
) -> float | DataType:
"""Root Mean Squared Error for functional data."""

return mean_squared_error(

Check warning on line 848 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L848

Added line #L848 was not covered by tests
y_true,
y_pred,
sample_weight=sample_weight,
multioutput=multioutput,
squared=False, # Apply the square root directly
)


@overload
def mean_squared_log_error(
Expand Down Expand Up @@ -924,6 +951,13 @@
multioutput = 'raw_values', ndarray.

"""
if not squared:
warnings.warn(
"The 'squared' parameter is deprecated and will be removed in a future version. "
"Use `root_mean_squared_log_error` instead.",
FutureWarning,
stacklevel=2,
)
function = (
sklearn.metrics.mean_squared_log_error
if squared
Expand Down Expand Up @@ -1025,6 +1059,132 @@

return _multioutput_score_basis(y_true, multioutput, _msle_func)

@overload
def root_mean_squared_log_error(
y_true: DataType,
y_pred: DataType,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: Literal['uniform_average'] = 'uniform_average',
) -> float:
pass # noqa: WPS428

Check warning on line 1070 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1070

Added line #L1070 was not covered by tests


@overload
def root_mean_squared_log_error(
y_true: DataType,
y_pred: DataType,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: Literal['raw_values'],
) -> DataType:
pass # noqa: WPS428

Check warning on line 1081 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1081

Added line #L1081 was not covered by tests


@singledispatch
def root_mean_squared_log_error(
y_true: DataType,
y_pred: DataType,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: MultiOutputType = 'uniform_average',
) -> float | DataType:
r"""Root Mean Squared Log Error for :class:`~skfda.representation.FData`.

This function applies the same logic as `mean_squared_log_error`, but
directly takes the square root of the result.

Args:
y_true: True target values.
y_pred: Predicted values.
sample_weight: Sample weights.
multioutput: Return format (raw values or uniform average).

Returns:
Root mean squared logarithmic error.
"""
return sklearn.metrics.root_mean_squared_log_error(

Check warning on line 1106 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1106

Added line #L1106 was not covered by tests
y_true,
y_pred,
sample_weight=sample_weight,
multioutput=multioutput,
)

@root_mean_squared_log_error.register # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatagrid(
y_true: FDataGrid,
y_pred: FDataGrid,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: MultiOutputType = 'uniform_average',
) -> float | FDataGrid:

if np.any(y_true.data_matrix < 0) or np.any(y_pred.data_matrix < 0):
raise ValueError(

Check warning on line 1123 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1122-L1123

Added lines #L1122 - L1123 were not covered by tests
"Root Mean Squared Logarithmic Error cannot be used when "
"targets functions have negative values.",
)

return root_mean_squared_error(

Check warning on line 1128 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1128

Added line #L1128 was not covered by tests
np.log1p(y_true),
np.log1p(y_pred),
sample_weight=sample_weight,
multioutput=multioutput,
)

@root_mean_squared_log_error.register # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatairregular(
y_true: FDataIrregular,
y_pred: FDataIrregular,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: MultiOutputType = 'uniform_average',
) -> float:

if np.any(y_true.values < 0) or np.any(y_pred.values < 0):
raise ValueError(

Check warning on line 1145 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1144-L1145

Added lines #L1144 - L1145 were not covered by tests
"Root Mean Squared Logarithmic Error cannot be used when "
"targets functions have negative values.",
)

return root_mean_squared_error(

Check warning on line 1150 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1150

Added line #L1150 was not covered by tests
np.log1p(y_true),
np.log1p(y_pred),
sample_weight=sample_weight,
multioutput=multioutput,
)

@root_mean_squared_log_error.register # type: ignore[attr-defined, misc]
def _root_mean_squared_log_error_fdatabasis(
y_true: FDataBasis,
y_pred: FDataBasis,
*,
sample_weight: NDArrayFloat | None = None,
multioutput: MultiOutputType = 'uniform_average',
) -> float:

def _rmsle_func(x: EvalPointsType) -> NDArrayFloat:
y_true_eval = y_true(x)
y_pred_eval = y_pred(x)

Check warning on line 1168 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1166-L1168

Added lines #L1166 - L1168 were not covered by tests

if np.any(y_true_eval < 0) or np.any(y_pred_eval < 0):
raise ValueError(

Check warning on line 1171 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1170-L1171

Added lines #L1170 - L1171 were not covered by tests
"Root Mean Squared Logarithmic Error cannot be used when "
"targets functions have negative values.",
)

error: NDArrayFloat = np.sqrt(np.average(

Check warning on line 1176 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1176

Added line #L1176 was not covered by tests
(np.log1p(y_true_eval) - np.log1p(y_pred_eval)) ** 2,
weights=sample_weight,
axis=0,
))

# Verify that the error only contains 1 input point
assert error.shape[0] == 1
return error[0] # type: ignore [no-any-return]

Check warning on line 1184 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1183-L1184

Added lines #L1183 - L1184 were not covered by tests

return _multioutput_score_basis(y_true, multioutput, _rmsle_func)

Check warning on line 1186 in skfda/misc/scoring.py

View check run for this annotation

Codecov / codecov/patch

skfda/misc/scoring.py#L1186

Added line #L1186 was not covered by tests


@overload
def r2_score(
Expand Down