-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH - check datafit + penalty
compatibility with solver
#137
ENH - check datafit + penalty
compatibility with solver
#137
Conversation
With this PR, the errors are more verbose: In [1]: from skglm.estimators import GeneralizedLinearEstimator
from skglm.penalties import L0_5
from skglm.datafits import Quadratic, Logistic
from skglm.solvers import ProxNewton, AndersonCD
import numpy as np
In [2]: X = np.random.normal(0, 1, (30, 50))
y = np.random.normal(0, 1, (30,))
In [3]: clf = GeneralizedLinearEstimator(Quadratic(), L0_5(1.), ProxNewton())
In [4]: clf.fit(X, y)
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 clf.fit(X, y)
File ~/Documents/skglm/skglm/estimators.py:241, in GeneralizedLinearEstimator.fit(self, X, y)
238 self.datafit = self.datafit if self.datafit else Quadratic()
239 self.solver = self.solver if self.solver else AndersonCD()
--> 241 return _glm_fit(X, y, self, self.datafit, self.penalty, self.solver)
File ~/Documents/skglm/skglm/estimators.py:29, in _glm_fit(X, y, model, datafit, penalty, solver)
27 is_classif = isinstance(datafit, (Logistic, QuadraticSVC))
28 fit_intercept = solver.fit_intercept
---> 29 validate_solver(solver, datafit, penalty)
31 if is_classif:
32 check_classification_targets(y)
File ~/Documents/skglm/skglm/utils/dispatcher.py:21, in validate_solver(solver, datafit, penalty)
6 """Ensure the solver is suited for the `datafit` + `penalty` problem.
7
8 Parameters
(...)
17 Penalty.
18 """
19 if (isinstance(solver, ProxNewton)
20 and not set(("raw_grad", "raw_hessian")) <= set(dir(datafit))):
---> 21 raise Exception(
22 f"ProwNewton cannot optimize {datafit.__class__.__name__}, since `raw_grad`"
23 " and `raw_hessian` are not implemented.")
24 if ("ws_strategy" in dir(solver) and solver.ws_strategy == "subdiff"
25 and isinstance(penalty, (L0_5, L2_3))):
26 raise Exception(
27 "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). "
28 "Set ws_strategy to `fixpoint`.")
Exception: ProwNewton cannot optimize Quadratic, since `raw_grad` and `raw_hessian` are not implemented. |
_glm_fit
Looks nice @PABannier, this will definitely improve UX! From an API point of view, shouldn't this check be delegated to each solver? This way we don't have one big function, but Such functions could also take care of the initialization (e.g. stepsize computation) which is done on a solver basis. WDYT? |
@mathurinm Yes I think it's cleaner, currently refining the POC. |
This would be a nice addition if we can ship it in the 0.3 release @Badr-MOUFAD , given that we added a few datafits, penalties and solvers ! |
@Badr-MOUFAD the issue popped up in #188, do you have time to take this over ? A simple check, at the beginning of each solver, that the datafit and penalty are supported (eg AndersonCD does not support Gamma datafit) |
Sure, I will resume this PR. |
…into solver_dispatcher
_glm_fit
datafit + penalty
compatibility with solver
Requires #191 to be implemented fit to allow for better checks |
Let's perform checks all the time for now, this will simplify our lives. WDYT?
OK
@QB3 any opinion? |
Co-authored-by: mathurinm <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for everyone 💪
We should be careful in internalizing the jit-compilation of dataffits and penalties as jit-compilation undoes the datafit initialization.
I believe this PR brings several contributions to the API and touches many parts of the codebase. we better merge it and tackle the aforementioned issue in that in a separate PR.
A quick proof-of-concept of a function that checks if the combination
(solver, datafit, penalty)
is supported. Currently we have some edge cases where one can passProxNewton
solver withL0_5
penalty without any error being raised.Pros of this design: the validation rules are centralized and validating a 3-uple is a one-liner in
glm_fit
.Cons: we have to update the rules as we enhance the capabilities of the solver.
All in all, I think it is very valuable to have more verbose errors when fitting estimators (e.g. Ali Rahimi initially passed a combination Quadratic, L2_3, ProxNewton which cannot be optimized at the moment of writing).
Closes #101
Closes #90
Closes #109