ENH - check datafit + penalty compatibility with solver #137
Changes to ``skglm/solvers/base.py``:
@@ -1,11 +1,38 @@
-from abc import abstractmethod
+from abc import abstractmethod, ABC
+from skglm.utils.validation import check_attrs


-class BaseSolver():
-    """Base class for solvers."""
+class BaseSolver(ABC):
+    """Base class for solvers.
+
+    Attributes
+    ----------
+    _datafit_required_attr : list
+        List of attributes that must be implemented in Datafit.
Review comment: I think we can make these public @BadrMOUFAD? Maybe I'm missing a specific reason.

Review comment: Also, typo: missing "that must BE" here and below.

Reply: I think these two attributes should be read-only. While there is a way to make attributes read-only, namely using the ``property`` decorator, I believe it adds too much complexity to the code and hence doesn't serve our goal of making component implementation user-friendly. I opted for the "start with underscore" naming convention to make the variables private and signal to the user that these are attributes not to mess with.

Reply: Ok, this makes sense too!
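For readers unfamiliar with the ``property`` alternative mentioned above, a read-only attribute would look roughly like the minimal sketch below (the class and attribute values are hypothetical, not part of the PR); the thread opts for the simpler underscore convention instead.

```python
class ReadOnlyAttrsSolver:
    """Hypothetical sketch: exposing required attributes as read-only properties."""

    # real storage stays private
    _datafit_required_attr = ("get_global_lipschitz",)

    @property
    def datafit_required_attr(self):
        # no setter is defined, so assigning to this attribute raises AttributeError
        return self._datafit_required_attr


solver = ReadOnlyAttrsSolver()
print(solver.datafit_required_attr)  # ('get_global_lipschitz',)
```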
+
+    _penalty_required_attr : list
+        List of attributes that must be implemented in Penalty.
+
+    Notes
+    -----
+    For required attributes, if an attribute is given as a list of attributes
+    it means at least one of them should be implemented.
Review comment: I had trouble understanding this; an example would help (in which case do we want to check that one of several attributes is present?).

Reply: I 100% agree with you @mathurinm, I should have accompanied the docs with an example. In the FISTA solver, this

    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
    _penalty_required_attr = (("prox_1d", "prox_vec"),)

is interpreted as: the datafit must implement ``get_global_lipschitz`` and (``gradient`` or ``gradient_scalar``), and the penalty must implement (``prox_1d`` or ``prox_vec``). This is the way I implemented it.

Review comment: Can you be a bit more explicit, like: "if an element of the required attribute list is itself a list, it means that at least one of these attributes must be implemented"?

Reply: Yes indeed.
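To make the and/or semantics concrete, here is a minimal sketch of how such a nested specification could be interpreted. It is an illustration only, not skglm's actual ``check_attrs`` implementation, and the toy datafit class is made up:

```python
def has_required_attrs(obj, required_attrs):
    """Return True if `obj` satisfies a nested requirement spec.

    A plain string means the attribute is mandatory; a tuple of strings
    means at least one of them must be implemented.
    """
    for attr in required_attrs:
        if isinstance(attr, str):
            if not hasattr(obj, attr):
                return False
        else:  # tuple of alternatives: at least one must be present
            if not any(hasattr(obj, alt) for alt in attr):
                return False
    return True


class MyDatafit:
    def get_global_lipschitz(self, X, y): ...
    def gradient_scalar(self, X, y, Xw, j): ...


spec = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
print(has_required_attrs(MyDatafit(), spec))  # True: gradient_scalar satisfies the tuple
```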
+
+    For instance, if
+
+        _datafit_required_attr = (
+            "get_global_lipschitz",
+            ("gradient", "gradient_scalar")
+        )
+
+    it means the datafit must implement the methods ``get_global_lipschitz``
+    and (``gradient`` or ``gradient_scalar``).
+    """
+
+    _datafit_required_attr: list
+    _penalty_required_attr: list

     @abstractmethod
-    def solve(self, X, y, datafit, penalty, w_init, Xw_init):
+    def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
         """Solve an optimization problem.

         Parameters
@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
         stop_crit : float
             Value of stopping criterion at convergence.
         """
+
+    def custom_checks(self, X, y, datafit, penalty):
+        """Ensure the solver is suited for the `datafit` + `penalty` problem.
+
+        This method includes extra checks to perform
+        aside from checking attributes compatibility.
+
+        Parameters
+        ----------
+        X : array, shape (n_samples, n_features)
+            Training data.
+
+        y : array, shape (n_samples,)
+            Target values.
Review comment: Since the validation of datafit/penalty depends on the data, for instance when […]. For now, we can settle for […].

Review comment: I agree with @Badr-MOUFAD here; even if it is not needed at the moment, it's not too hard to see cases where this would happen, and an API change later would be painful.
+
+        datafit : instance of BaseDatafit
+            Datafit.
+
+        penalty : instance of BasePenalty
+            Penalty.
+        """
+        pass
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
+              *, run_checks=True):
+        """Solve the optimization problem after validating its compatibility.
+
+        A proxy of the ``_solve`` method that implicitly ensures the compatibility
+        of ``datafit`` and ``penalty`` with the solver.
+
+        Examples
+        --------
+        >>> ...
+        >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
+        """
+        if run_checks:
+            self._validate(X, y, datafit, penalty)
+
+        return self._solve(X, y, datafit, penalty, w_init, Xw_init)
+
+    def _validate(self, X, y, datafit, penalty):
+        # run `custom_checks` first, then check the required attributes
+        self.custom_checks(X, y, datafit, penalty)
+
+        # do not check for sparse support here; that check is made at the solver level,
+        # since some solvers like ProxNewton don't require methods for sparse support
+        check_attrs(datafit, self, self._datafit_required_attr)
+        check_attrs(penalty, self, self._penalty_required_attr)
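Putting the new API together, the sketch below shows how a solver subclass is expected to interact with ``solve`` / ``_solve`` / ``custom_checks``. The subclass, its required attributes, and its checks are illustrative only; just the base-class interface comes from the diff above (assuming ``BaseSolver`` lives in ``skglm.solvers.base`` as in the diffed file):

```python
import numpy as np

from skglm.solvers.base import BaseSolver


class DummySolver(BaseSolver):
    # attributes the solver requires from the datafit and the penalty
    _datafit_required_attr = ("gradient_scalar",)
    _penalty_required_attr = ("prox_1d",)

    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        # trivial "solver": return zero coefficients and no objective history
        w = np.zeros(X.shape[1])
        return w, np.array([]), 0.0

    def custom_checks(self, X, y, datafit, penalty):
        # solver-specific validation beyond the attribute checks
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y have inconsistent numbers of samples.")


# Callers go through `solve`, which validates then dispatches to `_solve`:
#     w, p_objs, stop_crit = DummySolver().solve(X, y, datafit, penalty)
# and the validation can be bypassed explicitly:
#     w, p_objs, stop_crit = DummySolver().solve(X, y, datafit, penalty, run_checks=False)
```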
Changes to the ``FISTA`` solver:
@@ -3,6 +3,7 @@
 from skglm.solvers.base import BaseSolver
 from skglm.solvers.common import construct_grad, construct_grad_sparse
 from skglm.utils.prox_funcs import _prox_vec
+from skglm.utils.validation import check_attrs


 class FISTA(BaseSolver):

@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
     https://epubs.siam.org/doi/10.1137/080716542
     """

+    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
+    _penalty_required_attr = (("prox_1d", "prox_vec"),)
Review comment: Why does FISTA need ``prox_1d``? It is not used in the code below.

Reply: It is used in […]. The algorithm works if the penalty has either of ``prox_1d`` or ``prox_vec``.

Review comment: Got it, thanks!
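As a hedged illustration of that either/or requirement (the helper below is made up for the example and is not the PR's code), a solver can prefer a vectorized prox and fall back to the one-dimensional prox applied coordinate-wise:

```python
import numpy as np


def apply_prox(penalty, z, step):
    """Apply the penalty's proximal operator to the full vector `z`.

    Prefers a vectorized `prox_vec`; otherwise falls back to calling
    `prox_1d` coordinate by coordinate, which is why requiring
    ("prox_1d", "prox_vec") means "at least one of the two".
    """
    if hasattr(penalty, "prox_vec"):
        return penalty.prox_vec(z, step)
    return np.array([penalty.prox_1d(z[j], step, j) for j in range(z.shape[0])])
```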
+

     def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.max_iter = max_iter
         self.tol = tol

@@ -35,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.fit_intercept = False   # needed to be passed to GeneralizedLinearEstimator
         self.warm_start = False

-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         p_objs_out = []
         n_samples, n_features = X.shape
         all_features = np.arange(n_features)

@@ -46,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         z = w_init.copy() if w_init is not None else np.zeros(n_features)
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

-        try:
-            if X_is_sparse:
-                lipschitz = datafit.get_global_lipschitz_sparse(
-                    X.data, X.indptr, X.indices, y
-                )
-            else:
-                lipschitz = datafit.get_global_lipschitz(X, y)
-        except AttributeError as e:
-            sparse_suffix = '_sparse' if X_is_sparse else ''
-
-            raise Exception(
-                "Datafit is not compatible with FISTA solver.\n Datafit must "
-                f"implement `get_global_lipschitz{sparse_suffix}` method") from e
+        if X_is_sparse:
+            lipschitz = datafit.get_global_lipschitz_sparse(
+                X.data, X.indptr, X.indices, y
+            )
+        else:
+            lipschitz = datafit.get_global_lipschitz(X, y)

         for n_iter in range(self.max_iter):
             t_old = t_new

@@ -111,3 +108,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                 print(f"Stopping criterion max violation: {stop_crit:.2e}")
                 break
         return w, np.array(p_objs_out), stop_crit
+
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that the datafit supports sparse data when X is sparse
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=issparse(X)
+        )
+
+        # optimality strategy check
+        if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use `opt_strategy='subdiff'` in Fista solver."
+            )
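To see what the new upfront check buys over the old in-solver ``try``/``except``, here is a hedged sketch; the toy penalty is made up for the example, while ``FISTA`` and ``Quadratic`` are existing skglm classes:

```python
from skglm.datafits import Quadratic
from skglm.solvers import FISTA


class ProxOnlyPenalty:
    """Toy penalty implementing a prox but no `subdiff_distance`."""

    def prox_vec(self, x, stepsize):
        return x  # identity prox, for illustration only


# FISTA defaults to opt_strategy="subdiff", so calling
#     FISTA().solve(X, y, Quadratic(), ProxOnlyPenalty())
# now fails upfront with
#     AttributeError: Penalty must implement `subdiff_distance` ...
# instead of erroring deep inside the iterations.
```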
Review comment: How was this working before if the datafit and penalties were not compiled?

Reply: They were compiled inside the solver.