1
1
"""Base classes for the neighbor estimators"""
2
+ from __future__ import annotations
2
3
3
4
from abc import ABC
5
+ from typing import Any , Callable , Mapping , Tuple
4
6
5
7
import numpy as np
6
8
from sklearn .base import BaseEstimator , RegressorMixin
7
9
from sklearn .utils .validation import check_is_fitted as sklearn_check_is_fitted
10
+ from typing_extensions import Literal
8
11
9
12
from .. import FData , FDataGrid
10
13
from ..misc .metrics import l2_distance
14
+ from ..misc .metrics ._typing import Metric
15
+ from ..misc .metrics ._utils import _fit_metric
16
+ from ..representation ._typing import GridPoints , NDArrayFloat , NDArrayInt
11
17
12
18
13
- def _to_multivariate (fdatagrid ) :
19
+ def _to_multivariate (fdatagrid : FDataGrid ) -> NDArrayFloat :
14
20
r"""Returns the data matrix of a fdatagrid in flatten form compatible with
15
21
sklearn.
16
22
@@ -25,7 +31,12 @@ def _to_multivariate(fdatagrid):
25
31
return fdatagrid .data_matrix .reshape (fdatagrid .n_samples , - 1 )
26
32
27
33
28
- def _from_multivariate (data_matrix , grid_points , shape , ** kwargs ):
34
+ def _from_multivariate (
35
+ data_matrix : NDArrayFloat ,
36
+ grid_points : GridPoints ,
37
+ shape : Tuple [int , ...],
38
+ ** kwargs : Any ,
39
+ ) -> FDataGrid :
29
40
r"""Constructs a FDatagrid from the data matrix flattened.
30
41
31
42
Args:
@@ -42,7 +53,10 @@ def _from_multivariate(data_matrix, grid_points, shape, **kwargs):
42
53
return FDataGrid (data_matrix .reshape (shape ), grid_points , ** kwargs )
43
54
44
55
45
- def _to_multivariate_metric (metric , grid_points ):
56
+ def _to_multivariate_metric (
57
+ metric : Metric [FDataGrid ],
58
+ grid_points : GridPoints ,
59
+ ) -> Metric [NDArrayFloat ]:
46
60
r"""Transform a metric between FDatagrid in a sklearn compatible one.
47
61
48
62
Given a metric between FDatagrids returns a compatible metric used to
@@ -82,22 +96,36 @@ def _to_multivariate_metric(metric, grid_points):
82
96
# Shape -> (n_samples = 1, domain_dims...., image_dimension (-1))
83
97
shape = [1 ] + [len (axis ) for axis in grid_points ] + [- 1 ]
84
98
85
- def multivariate_metric (x , y , ** kwargs ):
99
+ def multivariate_metric (
100
+ x : NDArrayFloat ,
101
+ y : NDArrayFloat ,
102
+ ** kwargs : Any ,
103
+ ) -> NDArrayFloat :
86
104
87
- return metric (_from_multivariate (x , grid_points , shape ),
88
- _from_multivariate (y , grid_points , shape ),
89
- ** kwargs )
105
+ return metric (
106
+ _from_multivariate (x , grid_points , shape ),
107
+ _from_multivariate (y , grid_points , shape ),
108
+ ** kwargs ,
109
+ )
90
110
91
111
return multivariate_metric
92
112
93
113
94
114
class NeighborsBase (ABC , BaseEstimator ):
95
115
"""Base class for nearest neighbors estimators."""
96
116
97
- def __init__ (self , n_neighbors = None , radius = None ,
98
- weights = 'uniform' , algorithm = 'auto' ,
99
- leaf_size = 30 , metric = 'l2' , metric_params = None ,
100
- n_jobs = None , multivariate_metric = False ):
117
+ def __init__ (
118
+ self ,
119
+ n_neighbors : int | None = None ,
120
+ radius : float | None = None ,
121
+ weights : Literal ["uniform" , "distance" ] | Callable [[NDArrayFloat ], NDArrayFloat ] = "uniform" ,
122
+ algorithm : Literal ["auto" , "ball_tree" , "kd_tree" , "brute" ] = "auto" ,
123
+ leaf_size : int = 30 ,
124
+ metric : Literal ["precomputed" , "l2" ] | Metric [FDataGrid ] = 'l2' ,
125
+ metric_params : Mapping [str , Any ] | None = None ,
126
+ n_jobs : int | None = None ,
127
+ multivariate_metric : bool = False ,
128
+ ):
101
129
"""Initializes the nearest neighbors estimator"""
102
130
103
131
self .n_neighbors = n_neighbors
@@ -110,7 +138,7 @@ def __init__(self, n_neighbors=None, radius=None,
110
138
self .n_jobs = n_jobs
111
139
self .multivariate_metric = multivariate_metric
112
140
113
- def _check_is_fitted (self ):
141
+ def _check_is_fitted (self ) -> None :
114
142
"""Check if the estimator is fitted.
115
143
116
144
Raises:
@@ -119,7 +147,10 @@ def _check_is_fitted(self):
119
147
"""
120
148
sklearn_check_is_fitted (self , ['estimator_' ])
121
149
122
- def _transform_to_multivariate (self , X ):
150
+ def _transform_to_multivariate (
151
+ self ,
152
+ X : FDataGrid | None ,
153
+ ) -> NDArrayFloat | None :
123
154
"""Transform the input data to array form. If the metric is
124
155
precomputed it is not transformed.
125
156
@@ -131,9 +162,13 @@ def _transform_to_multivariate(self, X):
131
162
132
163
133
164
class NeighborsMixin :
134
- """Mixin class to train the neighbors models"""
165
+ """Mixin class to train the neighbors models. """
135
166
136
- def fit (self , X , y = None ):
167
+ def fit (
168
+ self ,
169
+ X : FDataGrid | NDArrayFloat ,
170
+ y : NDArrayFloat | NDArrayInt | None = None ,
171
+ ) -> NeighborsMixin :
137
172
"""Fit the model using X as training data and y as target values.
138
173
139
174
Args:
@@ -164,9 +199,13 @@ def fit(self, X, y=None):
164
199
else :
165
200
metric = self .metric
166
201
167
- sklearn_metric = _to_multivariate_metric (metric ,
168
- self ._grid_points )
202
+ _fit_metric (metric , X )
203
+ sklearn_metric = _to_multivariate_metric (
204
+ metric ,
205
+ self ._grid_points ,
206
+ )
169
207
else :
208
+ _fit_metric (self .metric , X )
170
209
sklearn_metric = self .metric
171
210
172
211
self .estimator_ = self ._init_estimator (sklearn_metric )
0 commit comments