Skip to content

Commit 80f4f23

Browse files
committed
Fit metric in k-NN.
1 parent 9abd38d commit 80f4f23

File tree

1 file changed

+56
-17
lines changed

1 file changed

+56
-17
lines changed

Diff for: skfda/ml/_neighbors_base.py

+56-17
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
"""Base classes for the neighbor estimators"""
2+
from __future__ import annotations
23

34
from abc import ABC
5+
from typing import Any, Callable, Mapping, Tuple
46

57
import numpy as np
68
from sklearn.base import BaseEstimator, RegressorMixin
79
from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
10+
from typing_extensions import Literal
811

912
from .. import FData, FDataGrid
1013
from ..misc.metrics import l2_distance
14+
from ..misc.metrics._typing import Metric
15+
from ..misc.metrics._utils import _fit_metric
16+
from ..representation._typing import GridPoints, NDArrayFloat, NDArrayInt
1117

1218

13-
def _to_multivariate(fdatagrid):
19+
def _to_multivariate(fdatagrid: FDataGrid) -> NDArrayFloat:
1420
r"""Returns the data matrix of an FDataGrid in flattened form compatible with
1521
sklearn.
1622
@@ -25,7 +31,12 @@ def _to_multivariate(fdatagrid):
2531
return fdatagrid.data_matrix.reshape(fdatagrid.n_samples, -1)
2632

2733

28-
def _from_multivariate(data_matrix, grid_points, shape, **kwargs):
34+
def _from_multivariate(
35+
data_matrix: NDArrayFloat,
36+
grid_points: GridPoints,
37+
shape: Tuple[int, ...],
38+
**kwargs: Any,
39+
) -> FDataGrid:
2940
r"""Constructs an FDataGrid from the flattened data matrix.
3041
3142
Args:
@@ -42,7 +53,10 @@ def _from_multivariate(data_matrix, grid_points, shape, **kwargs):
4253
return FDataGrid(data_matrix.reshape(shape), grid_points, **kwargs)
4354

4455

45-
def _to_multivariate_metric(metric, grid_points):
56+
def _to_multivariate_metric(
57+
metric: Metric[FDataGrid],
58+
grid_points: GridPoints,
59+
) -> Metric[NDArrayFloat]:
4660
r"""Transform a metric between FDataGrids into a sklearn-compatible one.
4761
4862
Given a metric between FDataGrids, returns a compatible metric used to
@@ -82,22 +96,36 @@ def _to_multivariate_metric(metric, grid_points):
8296
# Shape -> (n_samples = 1, domain_dims...., image_dimension (-1))
8397
shape = [1] + [len(axis) for axis in grid_points] + [-1]
8498

85-
def multivariate_metric(x, y, **kwargs):
99+
def multivariate_metric(
100+
x: NDArrayFloat,
101+
y: NDArrayFloat,
102+
**kwargs: Any,
103+
) -> NDArrayFloat:
86104

87-
return metric(_from_multivariate(x, grid_points, shape),
88-
_from_multivariate(y, grid_points, shape),
89-
**kwargs)
105+
return metric(
106+
_from_multivariate(x, grid_points, shape),
107+
_from_multivariate(y, grid_points, shape),
108+
**kwargs,
109+
)
90110

91111
return multivariate_metric
92112

93113

94114
class NeighborsBase(ABC, BaseEstimator):
95115
"""Base class for nearest neighbors estimators."""
96116

97-
def __init__(self, n_neighbors=None, radius=None,
98-
weights='uniform', algorithm='auto',
99-
leaf_size=30, metric='l2', metric_params=None,
100-
n_jobs=None, multivariate_metric=False):
117+
def __init__(
118+
self,
119+
n_neighbors: int | None = None,
120+
radius: float | None = None,
121+
weights: Literal["uniform", "distance"] | Callable[[NDArrayFloat], NDArrayFloat] = "uniform",
122+
algorithm: Literal["auto", "ball_tree", "kd_tree", "brute"] = "auto",
123+
leaf_size: int = 30,
124+
metric: Literal["precomputed", "l2"] | Metric[FDataGrid] = 'l2',
125+
metric_params: Mapping[str, Any] | None = None,
126+
n_jobs: int | None = None,
127+
multivariate_metric: bool = False,
128+
):
101129
"""Initializes the nearest neighbors estimator"""
102130

103131
self.n_neighbors = n_neighbors
@@ -110,7 +138,7 @@ def __init__(self, n_neighbors=None, radius=None,
110138
self.n_jobs = n_jobs
111139
self.multivariate_metric = multivariate_metric
112140

113-
def _check_is_fitted(self):
141+
def _check_is_fitted(self) -> None:
114142
"""Check if the estimator is fitted.
115143
116144
Raises:
@@ -119,7 +147,10 @@ def _check_is_fitted(self):
119147
"""
120148
sklearn_check_is_fitted(self, ['estimator_'])
121149

122-
def _transform_to_multivariate(self, X):
150+
def _transform_to_multivariate(
151+
self,
152+
X: FDataGrid | None,
153+
) -> NDArrayFloat | None:
123154
"""Transform the input data to array form. If the metric is
124155
precomputed it is not transformed.
125156
@@ -131,9 +162,13 @@ def _transform_to_multivariate(self, X):
131162

132163

133164
class NeighborsMixin:
134-
"""Mixin class to train the neighbors models"""
165+
"""Mixin class to train the neighbors models."""
135166

136-
def fit(self, X, y=None):
167+
def fit(
168+
self,
169+
X: FDataGrid | NDArrayFloat,
170+
y: NDArrayFloat | NDArrayInt | None = None,
171+
) -> NeighborsMixin:
137172
"""Fit the model using X as training data and y as target values.
138173
139174
Args:
@@ -164,9 +199,13 @@ def fit(self, X, y=None):
164199
else:
165200
metric = self.metric
166201

167-
sklearn_metric = _to_multivariate_metric(metric,
168-
self._grid_points)
202+
_fit_metric(metric, X)
203+
sklearn_metric = _to_multivariate_metric(
204+
metric,
205+
self._grid_points,
206+
)
169207
else:
208+
_fit_metric(self.metric, X)
170209
sklearn_metric = self.metric
171210

172211
self.estimator_ = self._init_estimator(sklearn_metric)

0 commit comments

Comments
 (0)