From 0154215e427fc78f34e5794922eefac400772156 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 10 Dec 2022 18:52:20 +0100 Subject: [PATCH 01/68] initial commit --- skglm/estimators.py | 2 ++ skglm/utils/dispatcher.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 skglm/utils/dispatcher.py diff --git a/skglm/estimators.py b/skglm/estimators.py index 9f0c1f41d..6164d4e3d 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -16,6 +16,7 @@ from sklearn.preprocessing import LabelEncoder from sklearn.multiclass import OneVsRestClassifier, check_classification_targets +from skglm.utils.dispatcher import validate_solver from skglm.utils.jit_compilation import compiled_clone from skglm.solvers import AndersonCD, MultiTaskBCD from skglm.datafits import Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask @@ -25,6 +26,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): is_classif = isinstance(datafit, (Logistic, QuadraticSVC)) fit_intercept = solver.fit_intercept + validate_solver(solver, datafit, penalty) if is_classif: check_classification_targets(y) diff --git a/skglm/utils/dispatcher.py b/skglm/utils/dispatcher.py new file mode 100644 index 000000000..cd76c7cb9 --- /dev/null +++ b/skglm/utils/dispatcher.py @@ -0,0 +1,28 @@ +from skglm.penalties import L0_5, L2_3 +from skglm.solvers import ProxNewton + + +def validate_solver(solver, datafit, penalty): + """Ensure the solver is suited for the `datafit` + `penalty` problem. + + Parameters + ---------- + solver : instance of BaseSolver + Solver. + + datafit : instance of BaseDatafit + Datafit. + + penalty : instance of BasePenalty + Penalty. + """ + if (isinstance(solver, ProxNewton) + and not set(("raw_grad", "raw_hessian")) <= set(dir(datafit))): + raise Exception( + f"ProwNewton cannot optimize {datafit.__class__.__name__}, since `raw_grad`" + " and `raw_hessian` are not implemented.") + if ("ws_strategy" in dir(solver) and solver.ws_strategy == "subdiff" + and isinstance(penalty, (L0_5, L2_3))): + raise Exception( + "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). " + "Set ws_strategy to `fixpoint`.") From 3c70b85186d5198bcfb00d97fe212065df2768ea Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 7 Jan 2023 16:43:41 +0100 Subject: [PATCH 02/68] delegated to solve --- skglm/solvers/anderson_cd.py | 7 +++++++ skglm/solvers/base.py | 14 ++++++++++++++ skglm/solvers/fista.py | 3 +++ skglm/solvers/gram_cd.py | 3 +++ skglm/solvers/group_bcd.py | 3 +++ skglm/solvers/group_prox_newton.py | 6 ++++++ skglm/solvers/multitask_bcd.py | 3 +++ skglm/solvers/prox_newton.py | 6 ++++++ skglm/utils/dispatcher.py | 28 ---------------------------- 9 files changed, 45 insertions(+), 28 deletions(-) delete mode 100644 skglm/utils/dispatcher.py diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 147d12c8e..6febd5f25 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -2,6 +2,7 @@ from numba import njit from scipy import sparse from sklearn.utils import check_array +from skglm.penalties import L0_5, L2_3 from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration @@ -257,6 +258,12 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, if return_n_iter: results += (n_iters,) return results + + def validate(self, datafit, penalty): + if self.ws_strategy == "subdiff" and isinstance(penalty, (L0_5, L2_3)): + raise Exception( + "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). " + "Set ws_strategy to `fixpoint`.") @njit diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 9b5c5b121..1fa7cb928 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -39,3 +39,17 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init): stop_crit : float Value of stopping criterion at convergence. """ + + @abstractmethod + def validate(self, datafit, penalty): + """Ensure the solver is suited for the `datafit` + `penalty` problem. + + Parameters + ---------- + datafit : instance of BaseDatafit + Datafit. + + penalty : instance of BasePenalty + Penalty. + """ + diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index e8a2fbfda..4f0fe45bd 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -102,3 +102,6 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): print(f"Stopping criterion max violation: {stop_crit:.2e}") break return w, np.array(p_objs_out), stop_crit + + def validate(self, datafit, penalty): + pass diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 7f17efb71..8f1c1f287 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -150,3 +150,6 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): grad += (w[j] - old_w_j) * scaled_gram[:, j] return penalty.subdiff_distance(w, grad, all_features) + + def validate(self, datafit, penalty): + pass diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 7824e6b63..0fe45971d 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -138,6 +138,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit + def validate(self, datafit, penalty): + pass + @njit def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws): diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 929ab853b..3e9584f22 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -142,6 +142,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit + def validate(self, datafit, penalty): + if not set(("raw_grad", "raw_hessian")) <= set(dir(datafit)): + raise Exception( + f"ProwNewton cannot optimize {datafit.__class__.__name__}, since" + + "`raw_grad` and `raw_hessian` are not implemented.") + @njit def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 7b7276940..27ac01419 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -222,6 +222,9 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) results += (n_iters,) return results + + def validate(self, datafit, penalty): + pass @njit diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 0b5b662bc..6d9c1d7e2 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -160,6 +160,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit + def validate(self, datafit, penalty): + if not set(("raw_grad", "raw_hessian")) <= set(dir(datafit)): + raise Exception( + f"ProwNewton cannot optimize {datafit.__class__.__name__}, since" + + "`raw_grad` and `raw_hessian` are not implemented.") + @njit def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, diff --git a/skglm/utils/dispatcher.py b/skglm/utils/dispatcher.py deleted file mode 100644 index cd76c7cb9..000000000 --- a/skglm/utils/dispatcher.py +++ /dev/null @@ -1,28 +0,0 @@ -from skglm.penalties import L0_5, L2_3 -from skglm.solvers import ProxNewton - - -def validate_solver(solver, datafit, penalty): - """Ensure the solver is suited for the `datafit` + `penalty` problem. - - Parameters - ---------- - solver : instance of BaseSolver - Solver. - - datafit : instance of BaseDatafit - Datafit. - - penalty : instance of BasePenalty - Penalty. - """ - if (isinstance(solver, ProxNewton) - and not set(("raw_grad", "raw_hessian")) <= set(dir(datafit))): - raise Exception( - f"ProwNewton cannot optimize {datafit.__class__.__name__}, since `raw_grad`" - " and `raw_hessian` are not implemented.") - if ("ws_strategy" in dir(solver) and solver.ws_strategy == "subdiff" - and isinstance(penalty, (L0_5, L2_3))): - raise Exception( - "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). " - "Set ws_strategy to `fixpoint`.") From ee0d29cea88915897b88ce3e2ba1c2e3ba3594c0 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 7 Jan 2023 17:20:03 +0100 Subject: [PATCH 03/68] call `solver.validate` --- skglm/estimators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 6164d4e3d..3fc9c8526 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -16,7 +16,6 @@ from sklearn.preprocessing import LabelEncoder from sklearn.multiclass import OneVsRestClassifier, check_classification_targets -from skglm.utils.dispatcher import validate_solver from skglm.utils.jit_compilation import compiled_clone from skglm.solvers import AndersonCD, MultiTaskBCD from skglm.datafits import Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask @@ -26,7 +25,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): is_classif = isinstance(datafit, (Logistic, QuadraticSVC)) fit_intercept = solver.fit_intercept - validate_solver(solver, datafit, penalty) + solver.validate(datafit, penalty) if is_classif: check_classification_targets(y) From 9abd1497a3ba64d928e66f5cbea83d1e5e9035ed Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 10:55:25 +0200 Subject: [PATCH 04/68] add validate method to solvers --- skglm/estimators.py | 1 - skglm/solvers/anderson_cd.py | 7 ++----- skglm/solvers/gram_cd.py | 5 +++-- skglm/solvers/group_prox_newton.py | 5 +---- skglm/solvers/lbfgs.py | 3 +++ skglm/solvers/multitask_bcd.py | 2 +- 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 65d9abe96..37c5f6ad8 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -28,7 +28,6 @@ def _glm_fit(X, y, model, datafit, penalty, solver): is_classif = isinstance(datafit, (Logistic, QuadraticSVC)) fit_intercept = solver.fit_intercept - solver.validate(datafit, penalty) if is_classif: check_classification_targets(y) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 6febd5f25..3f2c86ee4 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -258,12 +258,9 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, if return_n_iter: results += (n_iters,) return results - + def validate(self, datafit, penalty): - if self.ws_strategy == "subdiff" and isinstance(penalty, (L0_5, L2_3)): - raise Exception( - "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). " - "Set ws_strategy to `fixpoint`.") + pass @njit diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index c06d4d92e..4c99fb900 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -154,5 +154,6 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): return penalty.subdiff_distance(w, grad, all_features) - def validate(self, datafit, penalty): - pass + +def validate(self, datafit, penalty): + pass diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 3e9584f22..401dd6fc6 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -143,10 +143,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - if not set(("raw_grad", "raw_hessian")) <= set(dir(datafit)): - raise Exception( - f"ProwNewton cannot optimize {datafit.__class__.__name__}, since" - + "`raw_grad` and `raw_hessian` are not implemented.") + pass @njit diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index 5e7e03051..fc859cc6e 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -102,3 +102,6 @@ def callback_post_iter(w_k): stop_crit = norm(result.jac, ord=np.inf) return w, np.asarray(p_objs_out), stop_crit + + def validate(self, datafit, penalty): + pass diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 27ac01419..96f1bfaae 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -222,7 +222,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) results += (n_iters,) return results - + def validate(self, datafit, penalty): pass From d20229ae7e693cd1a622d6eafc84572c54aa3f67 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 11:48:40 +0200 Subject: [PATCH 05/68] implem validation logic --- skglm/solvers/prox_newton.py | 11 +++++++---- skglm/utils/validation.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index e57414266..dffb32d65 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -7,6 +7,8 @@ from sklearn.exceptions import ConvergenceWarning from skglm.utils.sparse_ops import _sparse_xj_dot +from skglm.utils.validation import check_obj_solver_compatibility + EPS_TOL = 0.3 MAX_CD_ITER = 20 @@ -48,6 +50,9 @@ class ProxNewton(BaseSolver): code: https://github.com/tbjohns/BlitzL1 """ + _datafit_required_attr = ("raw_grad", "raw_hessian") + _penalty_required_attr = ("prox_1d", "subdiff_distance") + def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=True, warm_start=False, verbose=0): self.p0 = p0 @@ -174,10 +179,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - if not set(("raw_grad", "raw_hessian")) <= set(dir(datafit)): - raise Exception( - f"ProwNewton cannot optimize {datafit.__class__.__name__}, since" - + "`raw_grad` and `raw_hessian` are not implemented.") + check_obj_solver_compatibility(datafit, ProxNewton._datafit_required_attr) + check_obj_solver_compatibility(penalty, ProxNewton._penalty_required_attr) @njit diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 0da22df40..bfcc5a4f8 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -23,3 +23,16 @@ def check_group_compatible(obj): f"'{obj_name}' is not block-separable. " f"Missing '{attr}' attribute." ) + + +def check_obj_solver_compatibility(obj, required_attr): + missing_attrs = [f"`{attr}`" for attr in required_attr if not hasattr(obj, attr)] + + if len(missing_attrs): + required_attr = ' and '.join(f"`{attr}`" for attr in required_attr) + + raise AttributeError( + "Object not compatible with solver. " + f"It must implement {' and '.join(required_attr)} \n" + f"Missing {' and '.join(missing_attrs)}." + ) From 310c572ff144c0ec3fe45c095b8c4f9e05d727ee Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 13:54:23 +0200 Subject: [PATCH 06/68] implem attribute validation for solvers --- skglm/solvers/anderson_cd.py | 8 ++++++-- skglm/solvers/fista.py | 3 +++ skglm/solvers/gram_cd.py | 19 +++++++++++++++---- skglm/solvers/group_bcd.py | 13 +++++++++++-- skglm/solvers/group_prox_newton.py | 16 +++++++++++----- skglm/solvers/lbfgs.py | 7 ++++++- skglm/solvers/multitask_bcd.py | 7 ++++++- skglm/solvers/prox_newton.py | 2 +- 8 files changed, 59 insertions(+), 16 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 3f2c86ee4..d461552be 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -2,10 +2,10 @@ from numba import njit from scipy import sparse from sklearn.utils import check_array -from skglm.penalties import L0_5, L2_3 from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration +from skglm.utils.validation import check_obj_solver_compatibility class AndersonCD(BaseSolver): @@ -46,6 +46,9 @@ class AndersonCD(BaseSolver): code: https://github.com/mathurinm/andersoncd """ + _datafit_required_attr = ("gradient_scalar",) + _penalty_required_attr = ("prox_1d", "subdiff_distance") + def __init__(self, max_iter=50, max_epochs=50_000, p0=10, tol=1e-4, ws_strategy="subdiff", fit_intercept=True, warm_start=False, verbose=0): @@ -260,7 +263,8 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, return results def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, AndersonCD._datafit_required_attr) + check_obj_solver_compatibility(penalty, AndersonCD._penalty_required_attr) @njit diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 993170063..9c2820429 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -27,6 +27,9 @@ class FISTA(BaseSolver): https://epubs.siam.org/doi/10.1137/080716542 """ + _datafit_required_attr = ("init_global_lipschitz",) + _penalty_required_attr = ("subdiff_distance",) + def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0): self.max_iter = max_iter self.tol = tol diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 4c99fb900..6c83f67a4 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -2,8 +2,11 @@ import numpy as np from numba import njit from scipy.sparse import issparse + +from skglm.datafits import Quadratic from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration +from skglm.utils.validation import check_obj_solver_compatibility class GramCD(BaseSolver): @@ -48,6 +51,9 @@ class GramCD(BaseSolver): Amount of verbosity. 0/False is silent. """ + _datafit_required_attr = ("gradient_scalar",) + _penalty_required_attr = ("prox_1d", "subdiff_distance") + def __init__(self, max_iter=100, use_acc=True, greedy_cd=True, tol=1e-4, fit_intercept=True, warm_start=False, verbose=0): self.max_iter = max_iter @@ -131,6 +137,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.array(p_objs_out), stop_crit + def validate(self, datafit, penalty): + if datafit.__class__ is not Quadratic: + raise AttributeError( + f"`GramCD` supports only `Quadratic` datafit. got {datafit}" + ) + + check_obj_solver_compatibility(datafit, GramCD._datafit_required_attr) + check_obj_solver_compatibility(penalty, GramCD._penalty_required_attr) + @njit def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): @@ -153,7 +168,3 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): grad += (w[j] - old_w_j) * scaled_gram[:, j] return penalty.subdiff_distance(w, grad, all_features) - - -def validate(self, datafit, penalty): - pass diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 0fe45971d..34f511c9c 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -3,7 +3,9 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_group_compatible +from skglm.utils.validation import ( + check_group_compatible, check_obj_solver_compatibility +) class GroupBCD(BaseSolver): @@ -35,6 +37,9 @@ class GroupBCD(BaseSolver): Amount of verbosity. 0/False is silent. """ + _datafit_required_attr = ("gradient_g",) + _penalty_required_attr = ("subdiff_distance", "prox_1group") + def __init__(self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): self.max_iter = max_iter @@ -139,7 +144,11 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, GroupBCD._datafit_required_attr) + check_obj_solver_compatibility(penalty, GroupBCD._penalty_required_attr) + + check_group_compatible(datafit) + check_group_compatible(penalty) @njit diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 401dd6fc6..96f4c6370 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -2,7 +2,9 @@ from numba import njit from numpy.linalg import norm from skglm.solvers.base import BaseSolver -from skglm.utils.validation import check_group_compatible +from skglm.utils.validation import ( + check_group_compatible, check_obj_solver_compatibility +) EPS_TOL = 0.3 MAX_CD_ITER = 20 @@ -41,6 +43,9 @@ class GroupProxNewton(BaseSolver): code: https://github.com/tbjohns/BlitzL1 """ + _datafit_required_attr = ("raw_grad", "raw_hessian") + _penalty_required_attr = ("prox_1d", "subdiff_distance") + def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): self.p0 = p0 @@ -52,9 +57,6 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.verbose = verbose def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): - check_group_compatible(datafit) - check_group_compatible(penalty) - fit_intercept = self.fit_intercept n_samples, n_features = X.shape grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices @@ -143,7 +145,11 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, GroupProxNewton._datafit_required_attr) + check_obj_solver_compatibility(penalty, GroupProxNewton._penalty_required_attr) + + check_group_compatible(datafit) + check_group_compatible(penalty) @njit diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index fc859cc6e..0112b9c8b 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -7,6 +7,7 @@ from scipy.sparse import issparse from skglm.solvers import BaseSolver +from skglm.utils.validation import check_obj_solver_compatibility class LBFGS(BaseSolver): @@ -27,6 +28,9 @@ class LBFGS(BaseSolver): Amount of verbosity. 0/False is silent. """ + _datafit_required_attr = ("gradient",) + _penalty_required_attr = ("gradient",) + def __init__(self, max_iter=50, tol=1e-4, verbose=False): self.max_iter = max_iter self.tol = tol @@ -104,4 +108,5 @@ def callback_post_iter(w_k): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, LBFGS._datafit_required_attr) + check_obj_solver_compatibility(penalty, LBFGS._penalty_required_attr) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 96f1bfaae..cdee91768 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -4,11 +4,15 @@ from numpy.linalg import norm from sklearn.utils import check_array from skglm.solvers.base import BaseSolver +from skglm.utils.validation import check_obj_solver_compatibility class MultiTaskBCD(BaseSolver): """Block coordinate descent solver for multi-task problems.""" + _datafit_required_attr = ("gradient_j",) + _penalty_required_attr = ("subdiff_distance", "prox_1feat") + def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6, use_acc=True, ws_strategy="subdiff", fit_intercept=True, warm_start=False, verbose=0): @@ -224,7 +228,8 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, MultiTaskBCD._datafit_required_attr) + check_obj_solver_compatibility(penalty, MultiTaskBCD._penalty_required_attr) @njit diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index dffb32d65..eebe12092 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -51,7 +51,7 @@ class ProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("prox_1d", "subdiff_distance") + _penalty_required_attr = ("subdiff_distance", "gradient_g") def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=True, warm_start=False, verbose=0): From 1872e4e4fad2c349b31c2a33ceb952f9cceda01f Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 13:59:49 +0200 Subject: [PATCH 07/68] validation ``PDCD_WS`` --- skglm/experimental/pdcd_ws.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index b81a68f5f..96eaab409 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -6,6 +6,8 @@ from numba import njit from skglm.utils.jit_compilation import compiled_clone +from skglm.utils.validation import check_obj_solver_compatibility + from sklearn.exceptions import ConvergenceWarning @@ -78,6 +80,9 @@ class PDCD_WS: https://arxiv.org/abs/2204.07826 """ + _datafit_required_attr = ('prox_conjugate',) + _penalty_required_attr = ("prox_1d",) + def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, p0=100, tol=1e-6, verbose=False): self.max_iter = max_iter @@ -197,23 +202,13 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, break @staticmethod - def _validate_init(datafit_, penalty_): - # validate datafit - missing_attrs = [] - for attr in ('prox_conjugate', 'subdiff_distance'): - if not hasattr(datafit_, attr): - missing_attrs.append(f"`{attr}`") - - if len(missing_attrs): - raise AttributeError( - "Datafit is not compatible with PDCD_WS solver.\n" - "Datafit must implement `prox_conjugate` and `subdiff_distance`.\n" - f"Missing {' and '.join(missing_attrs)}." - ) + def _validate_init(datafit, penalty): + check_obj_solver_compatibility(datafit, PDCD_WS._datafit_required_attr) + check_obj_solver_compatibility(penalty, PDCD_WS._penalty_required_attr) # jit compile classes - compiled_datafit = compiled_clone(datafit_) - compiled_penalty = compiled_clone(penalty_) + compiled_datafit = compiled_clone(datafit) + compiled_penalty = compiled_clone(penalty) return compiled_datafit, compiled_penalty From cd1ba1cc9e49b471fe86b20f5890d3a4a3de741d Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 14:08:52 +0200 Subject: [PATCH 08/68] fix trailing spaces --- skglm/solvers/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 1fa7cb928..990dcaedb 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -39,7 +39,7 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init): stop_crit : float Value of stopping criterion at convergence. """ - + @abstractmethod def validate(self, datafit, penalty): """Ensure the solver is suited for the `datafit` + `penalty` problem. @@ -52,4 +52,3 @@ def validate(self, datafit, penalty): penalty : instance of BasePenalty Penalty. """ - From 4c9b4d44fd8849e9acac312289d09d54aaa324ee Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 14:09:07 +0200 Subject: [PATCH 09/68] add docs to ``check_obj_solver_compatibility`` --- skglm/utils/validation.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index bfcc5a4f8..c452930a1 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -26,6 +26,22 @@ def check_group_compatible(obj): def check_obj_solver_compatibility(obj, required_attr): + """Check whether datafit or penalty is compatible with a solver. + + Parameters + ---------- + obj : Instance of Datafit or Penalty + The instance Datafit (or Penalty) to check. + + required_attr : List or tuple of strings + The attributes that ``obj`` must have. + + Raises + ------ + AttributeError + if any of the attribute in ``required_attr`` is missing + from ``obj`` attributes. + """ missing_attrs = [f"`{attr}`" for attr in required_attr if not hasattr(obj, attr)] if len(missing_attrs): From c3d01c479fa84a0a7c9f59ff407e2de21dc6ec13 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 14:32:11 +0200 Subject: [PATCH 10/68] add validation glm_fit --- skglm/estimators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skglm/estimators.py b/skglm/estimators.py index 37c5f6ad8..65d9abe96 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -28,6 +28,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): is_classif = isinstance(datafit, (Logistic, QuadraticSVC)) fit_intercept = solver.fit_intercept + solver.validate(datafit, penalty) if is_classif: check_classification_targets(y) From eacea147fbcdfe341c8b6ccf456566f2eed3bec7 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 14:32:38 +0200 Subject: [PATCH 11/68] fix Error logs --- skglm/solvers/gram_cd.py | 2 +- skglm/utils/validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 6c83f67a4..6ba5b9ba6 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -140,7 +140,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def validate(self, datafit, penalty): if datafit.__class__ is not Quadratic: raise AttributeError( - f"`GramCD` supports only `Quadratic` datafit. got {datafit}" + f"`GramCD` supports only `Quadratic` datafit, got {datafit}" ) check_obj_solver_compatibility(datafit, GramCD._datafit_required_attr) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index c452930a1..3bdcdc7a9 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -45,7 +45,7 @@ def check_obj_solver_compatibility(obj, required_attr): missing_attrs = [f"`{attr}`" for attr in required_attr if not hasattr(obj, attr)] if len(missing_attrs): - required_attr = ' and '.join(f"`{attr}`" for attr in required_attr) + required_attr = [f"`{attr}`" for attr in required_attr] raise AttributeError( "Object not compatible with solver. " From 573cb78eda4b317ab5ed1992098ab21a43fdc805 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 14:33:05 +0200 Subject: [PATCH 12/68] fix prox solvers attribute names --- skglm/solvers/group_prox_newton.py | 2 +- skglm/solvers/prox_newton.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 96f4c6370..6cf5f2f62 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -44,7 +44,7 @@ class GroupProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("prox_1d", "subdiff_distance") + _penalty_required_attr = ("subdiff_distance", "prox_1group") def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index eebe12092..72b7fcc40 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -51,7 +51,7 @@ class ProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("subdiff_distance", "gradient_g") + _penalty_required_attr = ("subdiff_distance", "prox_1d") def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=True, warm_start=False, verbose=0): From 8dedf18498e7da0cf03e8f6779f7b82a12d0dfad Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 15:12:42 +0200 Subject: [PATCH 13/68] add ``initialize`` to required attributes --- skglm/solvers/anderson_cd.py | 2 +- skglm/solvers/gram_cd.py | 2 +- skglm/solvers/group_bcd.py | 2 +- skglm/solvers/multitask_bcd.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index d461552be..dfbb99a61 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -46,7 +46,7 @@ class AndersonCD(BaseSolver): code: https://github.com/mathurinm/andersoncd """ - _datafit_required_attr = ("gradient_scalar",) + _datafit_required_attr = ("initialize", "gradient_scalar") _penalty_required_attr = ("prox_1d", "subdiff_distance") def __init__(self, max_iter=50, max_epochs=50_000, p0=10, diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 6ba5b9ba6..06cfdd33b 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -138,7 +138,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.array(p_objs_out), stop_crit def validate(self, datafit, penalty): - if datafit.__class__ is not Quadratic: + if not isinstance(datafit, Quadratic): raise AttributeError( f"`GramCD` supports only `Quadratic` datafit, got {datafit}" ) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 34f511c9c..df9f49aee 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -37,7 +37,7 @@ class GroupBCD(BaseSolver): Amount of verbosity. 0/False is silent. """ - _datafit_required_attr = ("gradient_g",) + _datafit_required_attr = ("initialize", "gradient_g") _penalty_required_attr = ("subdiff_distance", "prox_1group") def __init__(self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index cdee91768..c0a218439 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -10,7 +10,7 @@ class MultiTaskBCD(BaseSolver): """Block coordinate descent solver for multi-task problems.""" - _datafit_required_attr = ("gradient_j",) + _datafit_required_attr = ("initialize", "gradient_j") _penalty_required_attr = ("subdiff_distance", "prox_1feat") def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6, From 19fbe63c6b40873fc8f212f266d86c5f58e664bb Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 15:18:56 +0200 Subject: [PATCH 14/68] add change to what's new --- doc/changes/0.4.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/changes/0.4.rst b/doc/changes/0.4.rst index ed09b0101..5ee01ea9b 100644 --- a/doc/changes/0.4.rst +++ b/doc/changes/0.4.rst @@ -1,5 +1,7 @@ .. _changes_0_4: Version 0.4 (in progress) ---------------------------- +------------------------- + - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator ` (PR: :gh:`184`) +- Check compatibility between ``datafit + penalty`` and solver (PR :gh:`137`) From 3b65445578020a326605a0a46e516fb658c398a3 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 18 Oct 2023 15:38:15 +0200 Subject: [PATCH 15/68] formatting & Fista validation --- skglm/solvers/fista.py | 4 +++- skglm/utils/validation.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 9c2820429..d06faace7 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -3,6 +3,7 @@ from skglm.solvers.base import BaseSolver from skglm.solvers.common import construct_grad, construct_grad_sparse from skglm.utils.prox_funcs import _prox_vec +from skglm.utils.validation import check_obj_solver_compatibility class FISTA(BaseSolver): @@ -115,4 +116,5 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.array(p_objs_out), stop_crit def validate(self, datafit, penalty): - pass + check_obj_solver_compatibility(datafit, FISTA._datafit_required_attr) + check_obj_solver_compatibility(penalty, FISTA._penalty_required_attr) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 3bdcdc7a9..89e090510 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -49,6 +49,6 @@ def check_obj_solver_compatibility(obj, required_attr): raise AttributeError( "Object not compatible with solver. " - f"It must implement {' and '.join(required_attr)} \n" + f"It must implement {' and '.join(required_attr)}\n" f"Missing {' and '.join(missing_attrs)}." ) From ddc67b65f1a1c962daf9585f40a8e9bc1f5f6e3c Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 26 Oct 2023 13:43:14 +0200 Subject: [PATCH 16/68] get names obj and solver in ``check_obj_solver_attr_compatibility`` --- skglm/utils/validation.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 89e090510..4b14b383b 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -1,3 +1,5 @@ +import re + def check_group_compatible(obj): """Check whether ``obj`` is compatible with ``bcd_solver``. @@ -25,14 +27,17 @@ def check_group_compatible(obj): ) -def check_obj_solver_compatibility(obj, required_attr): - """Check whether datafit or penalty is compatible with a solver. +def check_obj_solver_attr_compatibility(obj, solver, required_attr): + """Check whether datafit or penalty is compatible with solver. Parameters ---------- obj : Instance of Datafit or Penalty The instance Datafit (or Penalty) to check. + solver : Instance of Solver + The instance of Solver to check. + required_attr : List or tuple of strings The attributes that ``obj`` must have. @@ -47,8 +52,14 @@ def check_obj_solver_compatibility(obj, required_attr): if len(missing_attrs): required_attr = [f"`{attr}`" for attr in required_attr] + # get name obj and solver + name_matcher = re.compile(r"\.(\w+)'>") + + obj_name = name_matcher.search(str(obj.__class__)).group(1) + solver_name = name_matcher.search(str(solver.__class__)).group(1) + raise AttributeError( - "Object not compatible with solver. " + f"{obj_name} is not compatible with {solver_name}. " f"It must implement {' and '.join(required_attr)}\n" f"Missing {' and '.join(missing_attrs)}." ) From 5d750701379d38e7fd271e7c4f2d2d84c62a8b99 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 26 Oct 2023 13:43:50 +0200 Subject: [PATCH 17/68] follow up changes --- skglm/experimental/pdcd_ws.py | 9 ++++----- skglm/solvers/anderson_cd.py | 8 ++++---- skglm/solvers/fista.py | 6 +++--- skglm/solvers/gram_cd.py | 6 +++--- skglm/solvers/group_bcd.py | 6 +++--- skglm/solvers/group_prox_newton.py | 6 +++--- skglm/solvers/lbfgs.py | 6 +++--- skglm/solvers/multitask_bcd.py | 6 +++--- skglm/solvers/prox_newton.py | 6 +++--- 9 files changed, 29 insertions(+), 30 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 96eaab409..677019d57 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -6,7 +6,7 @@ from numba import njit from skglm.utils.jit_compilation import compiled_clone -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility from sklearn.exceptions import ConvergenceWarning @@ -201,10 +201,9 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, if stop_crit_in <= tol_in: break - @staticmethod - def _validate_init(datafit, penalty): - check_obj_solver_compatibility(datafit, PDCD_WS._datafit_required_attr) - check_obj_solver_compatibility(penalty, PDCD_WS._penalty_required_attr) + def _validate_init(self, datafit, penalty): + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) # jit compile classes compiled_datafit = compiled_clone(datafit) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 3efe46bc8..15f52a880 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -5,7 +5,7 @@ from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility class AndersonCD(BaseSolver): @@ -46,7 +46,7 @@ class AndersonCD(BaseSolver): code: https://github.com/mathurinm/andersoncd """ - _datafit_required_attr = ("initialize", "gradient_scalar") + _datafit_required_attr = ("get_lipschitz", "gradient_scalar") _penalty_required_attr = ("prox_1d", "subdiff_distance") def __init__(self, max_iter=50, max_epochs=50_000, p0=10, @@ -270,8 +270,8 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, return results def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, AndersonCD._datafit_required_attr) - check_obj_solver_compatibility(penalty, AndersonCD._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index b725ebaf1..512e13515 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -3,7 +3,7 @@ from skglm.solvers.base import BaseSolver from skglm.solvers.common import construct_grad, construct_grad_sparse from skglm.utils.prox_funcs import _prox_vec -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility class FISTA(BaseSolver): @@ -117,5 +117,5 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.array(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, FISTA._datafit_required_attr) - check_obj_solver_compatibility(penalty, FISTA._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 06cfdd33b..40cb85eda 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -6,7 +6,7 @@ from skglm.datafits import Quadratic from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility class GramCD(BaseSolver): @@ -143,8 +143,8 @@ def validate(self, datafit, penalty): f"`GramCD` supports only `Quadratic` datafit, got {datafit}" ) - check_obj_solver_compatibility(datafit, GramCD._datafit_required_attr) - check_obj_solver_compatibility(penalty, GramCD._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index df9f49aee..22fa8da37 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -4,7 +4,7 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration from skglm.utils.validation import ( - check_group_compatible, check_obj_solver_compatibility + check_group_compatible, check_obj_solver_attr_compatibility ) @@ -144,8 +144,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, GroupBCD._datafit_required_attr) - check_obj_solver_compatibility(penalty, GroupBCD._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 6cf5f2f62..7b12a00e3 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -3,7 +3,7 @@ from numpy.linalg import norm from skglm.solvers.base import BaseSolver from skglm.utils.validation import ( - check_group_compatible, check_obj_solver_compatibility + check_group_compatible, check_obj_solver_attr_compatibility ) EPS_TOL = 0.3 @@ -145,8 +145,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, GroupProxNewton._datafit_required_attr) - check_obj_solver_compatibility(penalty, GroupProxNewton._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index 0112b9c8b..2ac2a322a 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -7,7 +7,7 @@ from scipy.sparse import issparse from skglm.solvers import BaseSolver -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility class LBFGS(BaseSolver): @@ -108,5 +108,5 @@ def callback_post_iter(w_k): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, LBFGS._datafit_required_attr) - check_obj_solver_compatibility(penalty, LBFGS._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 9629fcd0d..9f2a1210b 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -4,7 +4,7 @@ from numpy.linalg import norm from sklearn.utils import check_array from skglm.solvers.base import BaseSolver -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility class MultiTaskBCD(BaseSolver): @@ -234,8 +234,8 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, MultiTaskBCD._datafit_required_attr) - check_obj_solver_compatibility(penalty, MultiTaskBCD._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) @njit diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 72b7fcc40..404352aae 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -7,7 +7,7 @@ from sklearn.exceptions import ConvergenceWarning from skglm.utils.sparse_ops import _sparse_xj_dot -from skglm.utils.validation import check_obj_solver_compatibility +from skglm.utils.validation import check_obj_solver_attr_compatibility EPS_TOL = 0.3 @@ -179,8 +179,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_compatibility(datafit, ProxNewton._datafit_required_attr) - check_obj_solver_compatibility(penalty, ProxNewton._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) @njit From ce118ca1effaa4f87ba00d2229ae1fa9a644dd2b Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 14:21:18 +0100 Subject: [PATCH 18/68] specify required attributes in solvers --- skglm/experimental/pdcd_ws.py | 20 ++++++++++---------- skglm/solvers/gram_cd.py | 2 +- skglm/solvers/group_bcd.py | 4 ++-- skglm/solvers/group_prox_newton.py | 2 +- skglm/solvers/multitask_bcd.py | 4 ++-- skglm/solvers/prox_newton.py | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 677019d57..ef0d0c07d 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -5,13 +5,14 @@ from scipy.sparse import issparse from numba import njit +from skglm.solvers import BaseSolver from skglm.utils.jit_compilation import compiled_clone from skglm.utils.validation import check_obj_solver_attr_compatibility from sklearn.exceptions import ConvergenceWarning -class PDCD_WS: +class PDCD_WS(BaseSolver): r"""Primal-Dual Coordinate Descent solver with working sets. It solves @@ -94,9 +95,14 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None): if issparse(X): - raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.") + raise ValueError("Sparse matrices are not yet support in `PDCD_WS` solver.") + + self.validate(datafit_, penalty_) + + # jit compile classes + datafit = compiled_clone(datafit_) + penalty = compiled_clone(penalty_) - datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_) n_samples, n_features = X.shape # init steps @@ -201,16 +207,10 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, if stop_crit_in <= tol_in: break - def _validate_init(self, datafit, penalty): + def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - # jit compile classes - compiled_datafit = compiled_clone(datafit) - compiled_penalty = compiled_clone(penalty) - - return compiled_datafit, compiled_penalty - @njit def _scores_primal(X, w, z, penalty, primal_steps, ws): diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 40cb85eda..b0acad95f 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -51,7 +51,7 @@ class GramCD(BaseSolver): Amount of verbosity. 0/False is silent. """ - _datafit_required_attr = ("gradient_scalar",) + _datafit_required_attr = () _penalty_required_attr = ("prox_1d", "subdiff_distance") def __init__(self, max_iter=100, use_acc=True, greedy_cd=True, tol=1e-4, diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index c2cc53063..0db918897 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -37,8 +37,8 @@ class GroupBCD(BaseSolver): Amount of verbosity. 0/False is silent. """ - _datafit_required_attr = ("initialize", "gradient_g") - _penalty_required_attr = ("subdiff_distance", "prox_1group") + _datafit_required_attr = ("get_lipschitz", "gradient_g") + _penalty_required_attr = ("prox_1group", "subdiff_distance") def __init__(self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 7b12a00e3..29e4dfa2c 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -44,7 +44,7 @@ class GroupProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("subdiff_distance", "prox_1group") + _penalty_required_attr = ("prox_1group", "subdiff_distance") def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 1abc1c08f..75d61ee5f 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -10,8 +10,8 @@ class MultiTaskBCD(BaseSolver): """Block coordinate descent solver for multi-task problems.""" - _datafit_required_attr = ("initialize", "gradient_j") - _penalty_required_attr = ("subdiff_distance", "prox_1feat") + _datafit_required_attr = ("get_lipschitz", "gradient_j") + _penalty_required_attr = ("prox_1feat", "subdiff_distance") def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6, use_acc=True, ws_strategy="subdiff", fit_intercept=True, diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 404352aae..7b8eb6ea2 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -51,7 +51,7 @@ class ProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("subdiff_distance", "prox_1d") + _penalty_required_attr = ("prox_1d", "subdiff_distance") def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=True, warm_start=False, verbose=0): From f38d12e0a1ac71431ccbe491ecee34d5c56b6e13 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 14:27:55 +0100 Subject: [PATCH 19/68] add ``*_required_attr`` in ``BaseSolver`` --- skglm/solvers/base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 990dcaedb..4c82cfb14 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -2,7 +2,19 @@ class BaseSolver(): - """Base class for solvers.""" + """Base class for solvers. + + Attributes + ---------- + _datafit_required_attr : list of str + List of attributes that must implemented in Datafit. + + _penalty_required_attr : list of str + List of attributes that must implemented in Penalty. + """ + + _datafit_required_attr: list + _penalty_required_attr: list @abstractmethod def solve(self, X, y, datafit, penalty, w_init, Xw_init): From d1aeb4caaa8142e82a7da87d8f3c58b68d3f1273 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 14:40:29 +0100 Subject: [PATCH 20/68] pass in solver to ``check_obj_solver_attr_compatibility`` --- skglm/solvers/fista.py | 4 ++-- skglm/solvers/group_prox_newton.py | 4 ++-- skglm/solvers/multitask_bcd.py | 4 ++-- skglm/solvers/prox_newton.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 512e13515..6d1fa9aa7 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -117,5 +117,5 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.array(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 29e4dfa2c..eda06af02 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -145,8 +145,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 75d61ee5f..1b8bb3d17 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -234,8 +234,8 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 7b8eb6ea2..abf8fa47a 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -179,8 +179,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self._penalty_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit From b4e2d8e773e72637c72b4d10e1ef60e3dc3b4765 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 15:22:40 +0100 Subject: [PATCH 21/68] handle solvers that supports ``ws_strategy='subdiff_distance'`` --- skglm/solvers/anderson_cd.py | 7 +++++-- skglm/solvers/fista.py | 7 +++++-- skglm/solvers/multitask_bcd.py | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 629bcf725..6b78edefb 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -49,7 +49,7 @@ class AndersonCD(BaseSolver): """ _datafit_required_attr = ("get_lipschitz", "gradient_scalar") - _penalty_required_attr = ("prox_1d", "subdiff_distance") + _penalty_required_attr = ("prox_1d",) def __init__(self, max_iter=50, max_epochs=50_000, p0=10, tol=1e-4, ws_strategy="subdiff", fit_intercept=True, @@ -273,7 +273,10 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) + + if self.ws_strategy == "subdiff": + penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) + check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) @njit diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 6d1fa9aa7..c13e155f6 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -29,7 +29,7 @@ class FISTA(BaseSolver): """ _datafit_required_attr = ("init_global_lipschitz",) - _penalty_required_attr = ("subdiff_distance",) + _penalty_required_attr = () def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0): self.max_iter = max_iter @@ -118,4 +118,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) + + if self.opt_strategy == "subdiff": + penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) + check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 1b8bb3d17..6ed858c3f 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -11,7 +11,7 @@ class MultiTaskBCD(BaseSolver): """Block coordinate descent solver for multi-task problems.""" _datafit_required_attr = ("get_lipschitz", "gradient_j") - _penalty_required_attr = ("prox_1feat", "subdiff_distance") + _penalty_required_attr = ("prox_1feat",) def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6, use_acc=True, ws_strategy="subdiff", fit_intercept=True, @@ -235,7 +235,10 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) + + if self.ws_strategy == "subdiff": + penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) + check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) @njit From 92f5d3b19c326e0ada734be1fc35e61255086ea1 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 16:32:30 +0100 Subject: [PATCH 22/68] revert solver with subdiff check --- skglm/solvers/anderson_cd.py | 7 ++----- skglm/solvers/fista.py | 7 ++----- skglm/solvers/multitask_bcd.py | 5 +---- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 6b78edefb..e8de4b004 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -49,7 +49,7 @@ class AndersonCD(BaseSolver): """ _datafit_required_attr = ("get_lipschitz", "gradient_scalar") - _penalty_required_attr = ("prox_1d",) + _penalty_required_attr = () def __init__(self, max_iter=50, max_epochs=50_000, p0=10, tol=1e-4, ws_strategy="subdiff", fit_intercept=True, @@ -273,10 +273,7 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - - if self.ws_strategy == "subdiff": - penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) - check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index c13e155f6..ce84b6a4b 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -28,7 +28,7 @@ class FISTA(BaseSolver): https://epubs.siam.org/doi/10.1137/080716542 """ - _datafit_required_attr = ("init_global_lipschitz",) + _datafit_required_attr = ("get_global_lipschitz",) _penalty_required_attr = () def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0): @@ -118,7 +118,4 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - - if self.opt_strategy == "subdiff": - penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) - check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 6ed858c3f..8143e31b3 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -235,10 +235,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) def validate(self, datafit, penalty): check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - - if self.ws_strategy == "subdiff": - penalty_required_attr = self._penalty_required_attr + ("subdiff_distance",) - check_obj_solver_attr_compatibility(penalty, self, penalty_required_attr) + check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) @njit From af7ae2cf70af2dd61847781912c7c4aa2d618288 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 2 Nov 2023 16:33:31 +0100 Subject: [PATCH 23/68] unittest validation && abc fixes --- skglm/penalties/base.py | 4 ---- skglm/penalties/non_separable.py | 7 +------ skglm/tests/test_validation.py | 25 +++++++++++++++++++++++++ 3 files changed, 26 insertions(+), 10 deletions(-) create mode 100644 skglm/tests/test_validation.py diff --git a/skglm/penalties/base.py b/skglm/penalties/base.py index b45254b71..fe83759c7 100644 --- a/skglm/penalties/base.py +++ b/skglm/penalties/base.py @@ -28,10 +28,6 @@ def params_to_dict(self): def value(self, w): """Value of penalty at vector w.""" - @abstractmethod - def prox_1d(self, value, stepsize, j): - """Proximal operator of penalty for feature j.""" - @abstractmethod def subdiff_distance(self, w, grad, ws): """Distance of negative gradient to subdifferential at w for features in `ws`. diff --git a/skglm/penalties/non_separable.py b/skglm/penalties/non_separable.py index c27079323..85f0f5831 100644 --- a/skglm/penalties/non_separable.py +++ b/skglm/penalties/non_separable.py @@ -49,12 +49,7 @@ def prox_vec(self, x, stepsize): return np.sign(x) * prox - def prox_1d(self, value, stepsize, j): - raise ValueError( - "No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead." - ) - def subdiff_distance(self, w, grad, ws): - return ValueError( + raise ValueError( "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`" ) diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py new file mode 100644 index 000000000..4a71e61cd --- /dev/null +++ b/skglm/tests/test_validation.py @@ -0,0 +1,25 @@ +import pytest +from skglm.datafits import Quadratic, Poisson +from skglm.penalties import L1 +from skglm.solvers import FISTA, ProxNewton +from skglm.utils.jit_compilation import compiled_clone + + +def test_datafit_penalty_solver_compatibility(): + with pytest.raises( + AttributeError, match="Missing `raw_grad` and `raw_hessian`" + ): + ProxNewton().validate( + compiled_clone(Quadratic()), compiled_clone(L1(1.)) + ) + + with pytest.raises( + AttributeError, match="Missing `get_global_lipschitz`" + ): + FISTA().validate( + compiled_clone(Poisson()), compiled_clone(L1(1.)) + ) + + +if __name__ == "__main__": + pass From 850f04f759491c82d3e453553749a4827277dc52 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 11:00:11 +0100 Subject: [PATCH 24/68] implicit validation in ``__call__`` --- skglm/solvers/anderson_cd.py | 5 ----- skglm/solvers/base.py | 9 +++++++++ skglm/solvers/fista.py | 5 ----- skglm/solvers/gram_cd.py | 4 ---- skglm/solvers/group_bcd.py | 7 +------ skglm/solvers/group_prox_newton.py | 8 ++------ skglm/solvers/lbfgs.py | 5 ----- skglm/solvers/multitask_bcd.py | 5 ----- skglm/solvers/prox_newton.py | 5 ----- 9 files changed, 12 insertions(+), 41 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index e8de4b004..520ee0233 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -7,7 +7,6 @@ ) from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_obj_solver_attr_compatibility class AndersonCD(BaseSolver): @@ -271,10 +270,6 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, results += (n_iters,) return results - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - @njit def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws): diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 4c82cfb14..40790a9bd 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from skglm.utils.validation import check_obj_solver_attr_compatibility class BaseSolver(): @@ -64,3 +65,11 @@ def validate(self, datafit, penalty): penalty : instance of BasePenalty Penalty. """ + + def __call__(self, X, y, datafit, penalty, w_init, Xw_init, **kwargs): + check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) + check_obj_solver_attr_compatibility(datafit, self, self._penalty_required_attr) + + self.validate(datafit, penalty) + + self.solve(X, y, datafit, penalty, w_init, Xw_init, **kwargs) diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index ce84b6a4b..c94918ab9 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -3,7 +3,6 @@ from skglm.solvers.base import BaseSolver from skglm.solvers.common import construct_grad, construct_grad_sparse from skglm.utils.prox_funcs import _prox_vec -from skglm.utils.validation import check_obj_solver_attr_compatibility class FISTA(BaseSolver): @@ -115,7 +114,3 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): print(f"Stopping criterion max violation: {stop_crit:.2e}") break return w, np.array(p_objs_out), stop_crit - - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index b0acad95f..7fadd5444 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -6,7 +6,6 @@ from skglm.datafits import Quadratic from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_obj_solver_attr_compatibility class GramCD(BaseSolver): @@ -143,9 +142,6 @@ def validate(self, datafit, penalty): f"`GramCD` supports only `Quadratic` datafit, got {datafit}" ) - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - @njit def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 0db918897..abb74ec75 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -3,9 +3,7 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import ( - check_group_compatible, check_obj_solver_attr_compatibility -) +from skglm.utils.validation import check_group_compatible class GroupBCD(BaseSolver): @@ -146,9 +144,6 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index eda06af02..e2d8d86cd 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -2,9 +2,8 @@ from numba import njit from numpy.linalg import norm from skglm.solvers.base import BaseSolver -from skglm.utils.validation import ( - check_group_compatible, check_obj_solver_attr_compatibility -) +from skglm.utils.validation import check_group_compatible + EPS_TOL = 0.3 MAX_CD_ITER = 20 @@ -145,9 +144,6 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index 2ac2a322a..d8a2172fc 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -7,7 +7,6 @@ from scipy.sparse import issparse from skglm.solvers import BaseSolver -from skglm.utils.validation import check_obj_solver_attr_compatibility class LBFGS(BaseSolver): @@ -106,7 +105,3 @@ def callback_post_iter(w_k): stop_crit = norm(result.jac, ord=np.inf) return w, np.asarray(p_objs_out), stop_crit - - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 8143e31b3..c997fe011 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -4,7 +4,6 @@ from numpy.linalg import norm from sklearn.utils import check_array from skglm.solvers.base import BaseSolver -from skglm.utils.validation import check_obj_solver_attr_compatibility class MultiTaskBCD(BaseSolver): @@ -233,10 +232,6 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - @njit def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index e4e5c7e62..ccb45fcf8 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -8,7 +8,6 @@ from sklearn.exceptions import ConvergenceWarning from skglm.utils.sparse_ops import _sparse_xj_dot -from skglm.utils.validation import check_obj_solver_attr_compatibility EPS_TOL = 0.3 @@ -198,10 +197,6 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): ) return w, np.asarray(p_objs_out), stop_crit - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - @njit def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, From 84a2cf8746c7d6a2082dbac02ca17124d0278c33 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 13:59:17 +0100 Subject: [PATCH 25/68] validation logic revisited --- skglm/utils/validation.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 4b14b383b..f5cab86f1 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -27,7 +27,7 @@ def check_group_compatible(obj): ) -def check_obj_solver_attr_compatibility(obj, solver, required_attr): +def check_obj_solver_attr(obj, solver, required_attr): """Check whether datafit or penalty is compatible with solver. Parameters @@ -47,10 +47,18 @@ def check_obj_solver_attr_compatibility(obj, solver, required_attr): if any of the attribute in ``required_attr`` is missing from ``obj`` attributes. """ - missing_attrs = [f"`{attr}`" for attr in required_attr if not hasattr(obj, attr)] + missing_attrs = [] + for attr in required_attr: + attributes = attr if not isinstance(attr, str) else (attr,) + + for a in attributes: + if hasattr(obj, a): + break + else: + missing_attrs.append(_join_attrs_with_or(attributes)) if len(missing_attrs): - required_attr = [f"`{attr}`" for attr in required_attr] + required_attr = [_join_attrs_with_or(attrs) for attrs in required_attr] # get name obj and solver name_matcher = re.compile(r"\.(\w+)'>") @@ -63,3 +71,15 @@ def check_obj_solver_attr_compatibility(obj, solver, required_attr): f"It must implement {' and '.join(required_attr)}\n" f"Missing {' and '.join(missing_attrs)}." ) + + +def _join_attrs_with_or(attrs): + # + if isinstance(attrs, str): + return f"`{attrs}`" + + if len(attrs) == 1: + return f"`{attrs[0]}`" + + out = " or ".join([f"`{a}`" for a in attrs]) + return f'"{out}"' From 05e5cffa2fa62aa5be0c2a143bdc23d379b53173 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 13:59:38 +0100 Subject: [PATCH 26/68] BaseSolver as abstract class --- skglm/solvers/base.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 40790a9bd..2b87d4293 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -1,8 +1,8 @@ -from abc import abstractmethod -from skglm.utils.validation import check_obj_solver_attr_compatibility +from abc import abstractmethod, ABC +from skglm.utils.validation import check_obj_solver_attr -class BaseSolver(): +class BaseSolver(ABC): """Base class for solvers. Attributes @@ -53,8 +53,7 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init): Value of stopping criterion at convergence. """ - @abstractmethod - def validate(self, datafit, penalty): + def custom_compatibility_check(self, datafit, penalty): """Ensure the solver is suited for the `datafit` + `penalty` problem. Parameters @@ -66,10 +65,14 @@ def validate(self, datafit, penalty): Penalty. """ - def __call__(self, X, y, datafit, penalty, w_init, Xw_init, **kwargs): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(datafit, self, self._penalty_required_attr) + def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + """""" + self._validate(datafit, penalty) + self.solve(X, y, datafit, penalty, w_init, Xw_init) - self.validate(datafit, penalty) + def _validate(self, datafit, penalty): + # + check_obj_solver_attr(datafit, self, self._datafit_required_attr) + check_obj_solver_attr(datafit, self, self._penalty_required_attr) - self.solve(X, y, datafit, penalty, w_init, Xw_init, **kwargs) + self.custom_compatibility_check(datafit, penalty) From 4aefc8c4943109540479c05efe957fc7899294c6 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 14:03:26 +0100 Subject: [PATCH 27/68] add required attributes --- skglm/solvers/anderson_cd.py | 2 +- skglm/solvers/fista.py | 4 ++-- skglm/solvers/prox_newton.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 520ee0233..0f55a1ad7 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -48,7 +48,7 @@ class AndersonCD(BaseSolver): """ _datafit_required_attr = ("get_lipschitz", "gradient_scalar") - _penalty_required_attr = () + _penalty_required_attr = ("prox_1d",) def __init__(self, max_iter=50, max_epochs=50_000, p0=10, tol=1e-4, ws_strategy="subdiff", fit_intercept=True, diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index c94918ab9..d6322dca0 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -27,8 +27,8 @@ class FISTA(BaseSolver): https://epubs.siam.org/doi/10.1137/080716542 """ - _datafit_required_attr = ("get_global_lipschitz",) - _penalty_required_attr = () + _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar")) + _penalty_required_attr = (("prox_1d", "prox_vec"),) def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0): self.max_iter = max_iter diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index ccb45fcf8..5a04df095 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -54,7 +54,7 @@ class ProxNewton(BaseSolver): """ _datafit_required_attr = ("raw_grad", "raw_hessian") - _penalty_required_attr = ("prox_1d", "subdiff_distance") + _penalty_required_attr = ("prox_1d",) def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, ws_strategy="subdiff", fit_intercept=True, warm_start=False, From 626fe77f2ac83bddd9e95ca57839bbff7ed04a91 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 14:25:23 +0100 Subject: [PATCH 28/68] pass & cleanups --- skglm/experimental/pdcd_ws.py | 5 ----- skglm/penalties/non_separable.py | 5 ----- skglm/solvers/base.py | 17 +++++++++++++---- skglm/solvers/gram_cd.py | 2 +- skglm/solvers/group_bcd.py | 2 +- skglm/solvers/group_prox_newton.py | 2 +- skglm/solvers/prox_newton.py | 1 - 7 files changed, 16 insertions(+), 18 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index ef0d0c07d..13bfc5034 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -7,7 +7,6 @@ from numba import njit from skglm.solvers import BaseSolver from skglm.utils.jit_compilation import compiled_clone -from skglm.utils.validation import check_obj_solver_attr_compatibility from sklearn.exceptions import ConvergenceWarning @@ -207,10 +206,6 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, if stop_crit_in <= tol_in: break - def validate(self, datafit, penalty): - check_obj_solver_attr_compatibility(datafit, self, self._datafit_required_attr) - check_obj_solver_attr_compatibility(penalty, self, self._penalty_required_attr) - @njit def _scores_primal(X, w, z, penalty, primal_steps, ws): diff --git a/skglm/penalties/non_separable.py b/skglm/penalties/non_separable.py index 85f0f5831..58f0b8c2e 100644 --- a/skglm/penalties/non_separable.py +++ b/skglm/penalties/non_separable.py @@ -48,8 +48,3 @@ def prox_vec(self, x, stepsize): prox[sorted_indices] = prox_SLOPE(abs_x[sorted_indices], alphas * stepsize) return np.sign(x) * prox - - def subdiff_distance(self, w, grad, ws): - raise ValueError( - "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`" - ) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 2b87d4293..8f87449be 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -53,7 +53,7 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init): Value of stopping criterion at convergence. """ - def custom_compatibility_check(self, datafit, penalty): + def custom_compatibility_check(self, X, y, datafit, penalty): """Ensure the solver is suited for the `datafit` + `penalty` problem. Parameters @@ -66,12 +66,21 @@ def custom_compatibility_check(self, datafit, penalty): """ def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): - """""" + """Solve the optimization problem after validating its compatibility. + + A proxy of ``solve`` method that implicitly ensures the compatibility + of datafit and penalty with the solver. + + Examples + -------- + >>> ... + >>> coefs, obj_out, stop_crit = solver(X, y, datafit, penalty) + """ self._validate(datafit, penalty) self.solve(X, y, datafit, penalty, w_init, Xw_init) - def _validate(self, datafit, penalty): - # + def _validate(self, X, y, datafit, penalty): + # execute both attributes checks and `custom_compatibility_check` check_obj_solver_attr(datafit, self, self._datafit_required_attr) check_obj_solver_attr(datafit, self, self._penalty_required_attr) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 7fadd5444..7445d3a00 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -136,7 +136,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.array(p_objs_out), stop_crit - def validate(self, datafit, penalty): + def custom_compatibility_check(self, X, y, datafit): if not isinstance(datafit, Quadratic): raise AttributeError( f"`GramCD` supports only `Quadratic` datafit, got {datafit}" diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index abb74ec75..901092d43 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -143,7 +143,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit - def validate(self, datafit, penalty): + def custom_compatibility_check(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index e2d8d86cd..6329ed4a0 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -143,7 +143,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit - def validate(self, datafit, penalty): + def custom_compatibility_check(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 5a04df095..529c3923b 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -9,7 +9,6 @@ from sklearn.exceptions import ConvergenceWarning from skglm.utils.sparse_ops import _sparse_xj_dot - EPS_TOL = 0.3 MAX_CD_ITER = 20 MAX_BACKTRACK_ITER = 20 From b50ae91e9ed1f709f76423a458290c4f703eb3e8 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 15:14:01 +0100 Subject: [PATCH 29/68] unittest && docs --- skglm/estimators.py | 1 - skglm/experimental/pdcd_ws.py | 2 -- skglm/solvers/base.py | 24 +++++++++++++++++++----- skglm/tests/test_validation.py | 12 ++++++++---- skglm/utils/validation.py | 5 +++-- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 65d9abe96..37c5f6ad8 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -28,7 +28,6 @@ def _glm_fit(X, y, model, datafit, penalty, solver): is_classif = isinstance(datafit, (Logistic, QuadraticSVC)) fit_intercept = solver.fit_intercept - solver.validate(datafit, penalty) if is_classif: check_classification_targets(y) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 13bfc5034..f0b0f7223 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -96,8 +96,6 @@ def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None): if issparse(X): raise ValueError("Sparse matrices are not yet support in `PDCD_WS` solver.") - self.validate(datafit_, penalty_) - # jit compile classes datafit = compiled_clone(datafit_) penalty = compiled_clone(penalty_) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 8f87449be..39bc55d04 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -7,11 +7,16 @@ class BaseSolver(ABC): Attributes ---------- - _datafit_required_attr : list of str + _datafit_required_attr : list List of attributes that must implemented in Datafit. - _penalty_required_attr : list of str + _penalty_required_attr : list List of attributes that must implemented in Penalty. + + Notes + ----- + For required attributes, if an attribute is given as a list of attributes + it means at least one of them should be implemented. """ _datafit_required_attr: list @@ -56,8 +61,17 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init): def custom_compatibility_check(self, X, y, datafit, penalty): """Ensure the solver is suited for the `datafit` + `penalty` problem. + This method includes extra checks to perform + aside from checking attributes compatibility. + Parameters ---------- + X : array, shape (n_samples, n_features) + Training data. + + y : array, shape (n_samples,) + Target values. + datafit : instance of BaseDatafit Datafit. @@ -69,7 +83,7 @@ def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): """Solve the optimization problem after validating its compatibility. A proxy of ``solve`` method that implicitly ensures the compatibility - of datafit and penalty with the solver. + of ``datafit`` and ``penalty`` with the solver. Examples -------- @@ -81,7 +95,7 @@ def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def _validate(self, X, y, datafit, penalty): # execute both attributes checks and `custom_compatibility_check` + self.custom_compatibility_check(X, y, datafit, penalty) + check_obj_solver_attr(datafit, self, self._datafit_required_attr) check_obj_solver_attr(datafit, self, self._penalty_required_attr) - - self.custom_compatibility_check(datafit, penalty) diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index 4a71e61cd..600e4e874 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -3,21 +3,25 @@ from skglm.penalties import L1 from skglm.solvers import FISTA, ProxNewton from skglm.utils.jit_compilation import compiled_clone +from skglm.utils.data import make_correlated_data def test_datafit_penalty_solver_compatibility(): + n_samples, n_features = 10, 20 + X, y, _ = make_correlated_data(n_samples, n_features, X_density=1.) + with pytest.raises( AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): - ProxNewton().validate( - compiled_clone(Quadratic()), compiled_clone(L1(1.)) + ProxNewton()._validate( + X, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) ) with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): - FISTA().validate( - compiled_clone(Poisson()), compiled_clone(L1(1.)) + FISTA()._validate( + X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index f5cab86f1..8372d0f27 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -51,6 +51,8 @@ def check_obj_solver_attr(obj, solver, required_attr): for attr in required_attr: attributes = attr if not isinstance(attr, str) else (attr,) + # if `attr` is a list check that at least one of them + # is within `obj` attributes for a in attributes: if hasattr(obj, a): break @@ -74,7 +76,6 @@ def check_obj_solver_attr(obj, solver, required_attr): def _join_attrs_with_or(attrs): - # if isinstance(attrs, str): return f"`{attrs}`" @@ -82,4 +83,4 @@ def _join_attrs_with_or(attrs): return f"`{attrs[0]}`" out = " or ".join([f"`{a}`" for a in attrs]) - return f'"{out}"' + return f"({out})" From e81409f284ea82a2e82ead3bb8dd2f37722bde20 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:30:20 +0100 Subject: [PATCH 30/68] validations for sparse data support --- skglm/solvers/anderson_cd.py | 9 +++++++++ skglm/solvers/fista.py | 28 +++++++++++++++------------- skglm/solvers/lbfgs.py | 9 +++++++++ skglm/solvers/multitask_bcd.py | 9 +++++++++ skglm/tests/test_validation.py | 15 +++++++++++---- skglm/utils/validation.py | 31 ++++++++++++++++++++----------- 6 files changed, 73 insertions(+), 28 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 0f55a1ad7..bd327c498 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -7,6 +7,7 @@ ) from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration +from skglm.utils.validation import check_obj_solver_attr class AndersonCD(BaseSolver): @@ -270,6 +271,14 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, results += (n_iters,) return results + def custom_compatibility_check(self, X, y, datafit, penalty): + # check datafit support sparse data + check_obj_solver_attr( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=sparse.issparse(X) + ) + @njit def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws): diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index d6322dca0..1365f2077 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -3,6 +3,7 @@ from skglm.solvers.base import BaseSolver from skglm.solvers.common import construct_grad, construct_grad_sparse from skglm.utils.prox_funcs import _prox_vec +from skglm.utils.validation import check_obj_solver_attr class FISTA(BaseSolver): @@ -49,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): z = w_init.copy() if w_init is not None else np.zeros(n_features) Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples) - try: - if X_is_sparse: - lipschitz = datafit.get_global_lipschitz_sparse( - X.data, X.indptr, X.indices, y - ) - else: - lipschitz = datafit.get_global_lipschitz(X, y) - except AttributeError as e: - sparse_suffix = '_sparse' if X_is_sparse else '' - - raise Exception( - "Datafit is not compatible with FISTA solver.\n Datafit must " - f"implement `get_global_lipschitz{sparse_suffix}` method") from e + if X_is_sparse: + lipschitz = datafit.get_global_lipschitz_sparse( + X.data, X.indptr, X.indices, y + ) + else: + lipschitz = datafit.get_global_lipschitz(X, y) for n_iter in range(self.max_iter): t_old = t_new @@ -114,3 +108,11 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): print(f"Stopping criterion max violation: {stop_crit:.2e}") break return w, np.array(p_objs_out), stop_crit + + def custom_compatibility_check(self, X, y, datafit, penalty): + # check datafit support sparse data + check_obj_solver_attr( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=issparse(X) + ) diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index d8a2172fc..df49e3cfe 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -7,6 +7,7 @@ from scipy.sparse import issparse from skglm.solvers import BaseSolver +from skglm.utils.validation import check_obj_solver_attr class LBFGS(BaseSolver): @@ -105,3 +106,11 @@ def callback_post_iter(w_k): stop_crit = norm(result.jac, ord=np.inf) return w, np.asarray(p_objs_out), stop_crit + + def custom_compatibility_check(self, X, y, datafit, penalty): + # check datafit support sparse data + check_obj_solver_attr( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=issparse(X) + ) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index c997fe011..ebbfb907c 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -4,6 +4,7 @@ from numpy.linalg import norm from sklearn.utils import check_array from skglm.solvers.base import BaseSolver +from skglm.utils.validation import check_obj_solver_attr class MultiTaskBCD(BaseSolver): @@ -232,6 +233,14 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results + def custom_compatibility_check(self, X, y, datafit, penalty): + # check datafit support sparse data + check_obj_solver_attr( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=sparse.issparse(X) + ) + @njit def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index 600e4e874..e92e3133a 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -7,21 +7,28 @@ def test_datafit_penalty_solver_compatibility(): - n_samples, n_features = 10, 20 - X, y, _ = make_correlated_data(n_samples, n_features, X_density=1.) + X_sparse, y, _ = make_correlated_data(n_samples=3, n_features=5, X_density=0.5) + X_dense = X_sparse.todense() with pytest.raises( AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): ProxNewton()._validate( - X, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) + X_dense, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) ) with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( - X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + X_dense, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + ) + + with pytest.raises( + AttributeError, match="Missing `get_global_lipschitz`" + ): + FISTA()._validate( + X_dense, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 8372d0f27..645eb10c5 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -1,6 +1,9 @@ import re +SPARSE_SUFFIX = "_sparse" + + def check_group_compatible(obj): """Check whether ``obj`` is compatible with ``bcd_solver``. @@ -27,7 +30,7 @@ def check_group_compatible(obj): ) -def check_obj_solver_attr(obj, solver, required_attr): +def check_obj_solver_attr(obj, solver, required_attr, support_sparse=False): """Check whether datafit or penalty is compatible with solver. Parameters @@ -41,6 +44,9 @@ def check_obj_solver_attr(obj, solver, required_attr): required_attr : List or tuple of strings The attributes that ``obj`` must have. + support_sparse : bool, default False + If ``True`` adds a ``SPARSE_SUFFIX`` to check compatibility with sparse data. + Raises ------ AttributeError @@ -48,19 +54,21 @@ def check_obj_solver_attr(obj, solver, required_attr): from ``obj`` attributes. """ missing_attrs = [] + suffix = SPARSE_SUFFIX if support_sparse else "" + + # if `attr` is a list check that at least one of them + # is within `obj` attributes for attr in required_attr: attributes = attr if not isinstance(attr, str) else (attr,) - # if `attr` is a list check that at least one of them - # is within `obj` attributes for a in attributes: - if hasattr(obj, a): + if hasattr(obj, f"{a}{suffix}"): break else: - missing_attrs.append(_join_attrs_with_or(attributes)) + missing_attrs.append(_join_attrs_with_or(attributes, suffix)) if len(missing_attrs): - required_attr = [_join_attrs_with_or(attrs) for attrs in required_attr] + required_attr = [_join_attrs_with_or(attrs, suffix) for attrs in required_attr] # get name obj and solver name_matcher = re.compile(r"\.(\w+)'>") @@ -69,18 +77,19 @@ def check_obj_solver_attr(obj, solver, required_attr): solver_name = name_matcher.search(str(solver.__class__)).group(1) raise AttributeError( - f"{obj_name} is not compatible with {solver_name}. " + f"{obj_name} is not compatible with {solver_name}" + " with sparse data. " if support_sparse else ". " f"It must implement {' and '.join(required_attr)}\n" f"Missing {' and '.join(missing_attrs)}." ) -def _join_attrs_with_or(attrs): +def _join_attrs_with_or(attrs, suffix=""): if isinstance(attrs, str): - return f"`{attrs}`" + return f"`{attrs}{suffix}`" if len(attrs) == 1: - return f"`{attrs[0]}`" + return f"`{attrs[0]}{suffix}`" - out = " or ".join([f"`{a}`" for a in attrs]) + out = " or ".join([f"`{a}{suffix}`" for a in attrs]) return f"({out})" From 39fa7b9257135e0917fa3cf4219617a8c880f001 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:45:55 +0100 Subject: [PATCH 31/68] validate ``subdiff`` and ``fixpoint`` --- skglm/solvers/anderson_cd.py | 7 +++++++ skglm/solvers/fista.py | 7 +++++++ skglm/solvers/multitask_bcd.py | 7 +++++++ skglm/solvers/prox_newton.py | 8 ++++++++ 4 files changed, 29 insertions(+) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index bd327c498..464dd2fb6 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -279,6 +279,13 @@ def custom_compatibility_check(self, X, y, datafit, penalty): support_sparse=sparse.issparse(X) ) + # ws strategy + if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use self.ws_strategy='subdiff'." + ) + @njit def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws): diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 1365f2077..33e2be46e 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -116,3 +116,10 @@ def custom_compatibility_check(self, X, y, datafit, penalty): required_attr=self._datafit_required_attr, support_sparse=issparse(X) ) + + # optimality check + if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use self.opt_strategy='subdiff'." + ) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index ebbfb907c..097cabb0f 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -241,6 +241,13 @@ def custom_compatibility_check(self, X, y, datafit, penalty): support_sparse=sparse.issparse(X) ) + # ws strategy + if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use self.ws_strategy='subdiff'." + ) + @njit def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 529c3923b..e542d0bf5 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -196,6 +196,14 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): ) return w, np.asarray(p_objs_out), stop_crit + def custom_compatibility_check(self, X, y, datafit, penalty): + # ws strategy + if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use self.ws_strategy='subdiff'." + ) + @njit def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, From 40117a434eeab0dec22d6a607b27669cd6086ff3 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:49:25 +0100 Subject: [PATCH 32/68] sparse support in group solvers --- skglm/experimental/pdcd_ws.py | 9 ++++++--- skglm/solvers/group_bcd.py | 6 ++++++ skglm/solvers/group_prox_newton.py | 7 +++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index f0b0f7223..a6ec2b206 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -93,9 +93,6 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, self.verbose = verbose def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None): - if issparse(X): - raise ValueError("Sparse matrices are not yet support in `PDCD_WS` solver.") - # jit compile classes datafit = compiled_clone(datafit_) penalty = compiled_clone(penalty_) @@ -204,6 +201,12 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, if stop_crit_in <= tol_in: break + def custom_compatibility_check(self, X, y, datafit, penalty): + if issparse(X): + raise ValueError( + "Sparse matrices are not yet support in `PDCD_WS` solver." + ) + @njit def _scores_primal(X, w, z, penalty, primal_steps, ws): diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 901092d43..b2c469a2e 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -1,5 +1,6 @@ import numpy as np from numba import njit +from scipy.sparse import issparse from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration @@ -147,6 +148,11 @@ def custom_compatibility_check(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) + if issparse(X): + raise ValueError( + "Sparse matrices are not yet support in `GroupBCD` solver." + ) + @njit def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws): diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 6329ed4a0..26667b7db 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -1,6 +1,8 @@ import numpy as np from numba import njit from numpy.linalg import norm +from scipy.sparse import issparse + from skglm.solvers.base import BaseSolver from skglm.utils.validation import check_group_compatible @@ -147,6 +149,11 @@ def custom_compatibility_check(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) + if issparse(X): + raise ValueError( + "Sparse matrices are not yet support in `GroupBCD` solver." + ) + @njit def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, From bab6277ce8c59bd6945bd15c7ddc49565b571b86 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:52:10 +0100 Subject: [PATCH 33/68] more on unittest --- skglm/tests/test_validation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index e92e3133a..829458e34 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -7,28 +7,27 @@ def test_datafit_penalty_solver_compatibility(): - X_sparse, y, _ = make_correlated_data(n_samples=3, n_features=5, X_density=0.5) - X_dense = X_sparse.todense() + X, y, _ = make_correlated_data(n_samples=3, n_features=5) with pytest.raises( AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): ProxNewton()._validate( - X_dense, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) + X, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) ) with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( - X_dense, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( - X_dense, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) From 7b795399e7de20ee0aec68a1b6636410d37b6f5e Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:53:22 +0100 Subject: [PATCH 34/68] fix what's new --- doc/changes/0.4.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/doc/changes/0.4.rst b/doc/changes/0.4.rst index 9a2350c49..2515a1de3 100644 --- a/doc/changes/0.4.rst +++ b/doc/changes/0.4.rst @@ -3,8 +3,6 @@ Version 0.4 (in progress) ------------------------- -- Add support for weights and positive coefficients to :ref:`MCPRegression Estimator ` (PR: :gh:`184`) -- Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`) - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator ` (PR: :gh:`184`) - Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`) - Add :ref:`LogSumPenalty ` (PR: :gh:`#127`) From 118a0daa172dd295600e6aa8290d7c238ff4b752 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 16:56:22 +0100 Subject: [PATCH 35/68] use ``__call__`` instead of ``solve`` --- skglm/estimators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 37c5f6ad8..f57503921 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -134,7 +134,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): "The size of the WeightedL1 penalty weights should be n_features, " "expected %i, got %i." % (X_.shape[1], len(penalty.weights))) - coefs, p_obj, kkt = solver.solve(X_, y, datafit_jit, penalty_jit, w, Xw) + coefs, p_obj, kkt = solver(X_, y, datafit_jit, penalty_jit, w, Xw) model.coef_, model.stop_crit_ = coefs[:n_features], kkt if y.ndim == 1: model.intercept_ = coefs[-1] if fit_intercept else 0. @@ -1350,7 +1350,7 @@ def fit(self, X, y): else: datafit.initialize_sparse(X.data, X.indptr, X.indices, y) - w, _, stop_crit = solver.solve(X, y, datafit, penalty) + w, _, stop_crit = solver(X, y, datafit, penalty) # save to attribute self.coef_ = w @@ -1482,7 +1482,7 @@ def fit(self, X, Y): self.max_iter, self.max_epochs, self.p0, tol=self.tol, ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept, warm_start=self.warm_start, verbose=self.verbose) - W, obj_out, kkt = solver.solve(X, Y, datafit_jit, penalty_jit) + W, obj_out, kkt = solver(X, Y, datafit_jit, penalty_jit) self.coef_ = W[:X.shape[1], :].T self.intercept_ = self.fit_intercept * W[-1, :] From 94f47c22367baeaaeaf43861522567d131fe19f7 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 16 Nov 2023 17:28:55 +0100 Subject: [PATCH 36/68] fix ``BaseSolver`` --- skglm/solvers/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 39bc55d04..565f27cf3 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -90,12 +90,12 @@ def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): >>> ... >>> coefs, obj_out, stop_crit = solver(X, y, datafit, penalty) """ - self._validate(datafit, penalty) - self.solve(X, y, datafit, penalty, w_init, Xw_init) + self._validate(X, y, datafit, penalty) + return self.solve(X, y, datafit, penalty, w_init, Xw_init) def _validate(self, X, y, datafit, penalty): - # execute both attributes checks and `custom_compatibility_check` + # execute both: attributes checks and `custom_compatibility_check` self.custom_compatibility_check(X, y, datafit, penalty) check_obj_solver_attr(datafit, self, self._datafit_required_attr) - check_obj_solver_attr(datafit, self, self._penalty_required_attr) + check_obj_solver_attr(penalty, self, self._penalty_required_attr) From 419df06b9aeffd95c136960158725325d5f7de9f Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Fri, 24 Nov 2023 09:59:26 +0100 Subject: [PATCH 37/68] Update skglm/solvers/group_bcd.py Co-authored-by: Quentin Bertrand --- skglm/solvers/group_bcd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index b2c469a2e..38678322d 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -150,7 +150,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if issparse(X): raise ValueError( - "Sparse matrices are not yet support in `GroupBCD` solver." + "Sparse matrices are not yet supported in `GroupBCD` solver." ) From a26387bd2f3721e6f9a34b8720cc906458909e81 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Fri, 24 Nov 2023 09:59:33 +0100 Subject: [PATCH 38/68] Update skglm/solvers/group_prox_newton.py Co-authored-by: Quentin Bertrand --- skglm/solvers/group_prox_newton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 26667b7db..e27540e23 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -151,7 +151,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if issparse(X): raise ValueError( - "Sparse matrices are not yet support in `GroupBCD` solver." + "Sparse matrices are not yet supported in `GroupBCD` solver." ) From 677d0e31e9bb1fb3decf7574432a4d2362403b9e Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Fri, 24 Nov 2023 09:59:47 +0100 Subject: [PATCH 39/68] Update skglm/experimental/pdcd_ws.py Co-authored-by: Quentin Bertrand --- skglm/experimental/pdcd_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index a6ec2b206..d0475c636 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -204,7 +204,7 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, def custom_compatibility_check(self, X, y, datafit, penalty): if issparse(X): raise ValueError( - "Sparse matrices are not yet support in `PDCD_WS` solver." + "Sparse matrices are not yet supported in `PDCD_WS` solver." ) From 8f840bfd794202cf5e840e5fa82d57fd255f7f55 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 11 Apr 2024 11:26:28 +0200 Subject: [PATCH 40/68] chenges.rst --- doc/changes/0.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/0.4.rst b/doc/changes/0.4.rst index 927f43b71..d2f726c75 100644 --- a/doc/changes/0.4.rst +++ b/doc/changes/0.4.rst @@ -4,7 +4,7 @@ Version 0.4 (in progress) ------------------------- - Add :ref:`GroupLasso Estimator ` (PR: :gh:`228`) - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty ` (PR: :gh:`221`) - +- Check compatibility with datafit and penalty in solver (PR :gh:`137`) Version 0.3.1 (2023/12/21) -------------------------- From d8cc02247e52db9fc384cf5c8ff28d758c747957 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 11 Apr 2024 13:02:37 +0200 Subject: [PATCH 41/68] sparse matrices are now supported by GroupBCD --- skglm/solvers/group_bcd.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 64d4300f3..f97b684df 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -168,11 +168,6 @@ def custom_compatibility_check(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) - if issparse(X): - raise ValueError( - "Sparse matrices are not yet supported in `GroupBCD` solver." - ) - @njit def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws): From fa9ed35dfcfe598c9e9abad20c0f682d608dad73 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 11:26:43 +0200 Subject: [PATCH 42/68] typo ``BaseSolver`` --- skglm/solvers/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 565f27cf3..5994968bd 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -8,10 +8,10 @@ class BaseSolver(ABC): Attributes ---------- _datafit_required_attr : list - List of attributes that must implemented in Datafit. + List of attributes that must be implemented in Datafit. _penalty_required_attr : list - List of attributes that must implemented in Penalty. + List of attributes that must be implemented in Penalty. Notes ----- From e687bc2ecaacfe920b9aaad3e33e1f0cbdbac683 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 11:32:55 +0200 Subject: [PATCH 43/68] rm ``self`` in docs --- skglm/solvers/multitask_bcd.py | 2 +- skglm/solvers/prox_newton.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 097cabb0f..4e81f9bc8 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -245,7 +245,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( "Penalty must implement `subdiff_distance` " - "to use self.ws_strategy='subdiff'." + "to use ws_strategy='subdiff'." ) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index e542d0bf5..7b2cc5d67 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -201,7 +201,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( "Penalty must implement `subdiff_distance` " - "to use self.ws_strategy='subdiff'." + "to use ws_strategy='subdiff'." ) From fd67df8d34c8f292e30d0584095c68f7a9a04f24 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 12:16:26 +0200 Subject: [PATCH 44/68] more code-readable attribute error --- skglm/utils/validation.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 645eb10c5..117fcedb4 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -76,12 +76,16 @@ def check_obj_solver_attr(obj, solver, required_attr, support_sparse=False): obj_name = name_matcher.search(str(obj.__class__)).group(1) solver_name = name_matcher.search(str(solver.__class__)).group(1) - raise AttributeError( - f"{obj_name} is not compatible with {solver_name}" - " with sparse data. " if support_sparse else ". " - f"It must implement {' and '.join(required_attr)}\n" - f"Missing {' and '.join(missing_attrs)}." - ) + if not support_sparse: + err_message = f"{obj_name} is not compatible with {solver_name}." + else: + err_message = (f"{obj_name} is not compatible with {solver_name}" + "with sparse data.") + + err_message += (f" It must implement {' and '.join(required_attr)}\n" + f"Missing {' and '.join(missing_attrs)}.") + + raise AttributeError(err_message) def _join_attrs_with_or(attrs, suffix=""): From 78295d84dcf30f74812141df2381504f719103c1 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 12:30:23 +0200 Subject: [PATCH 45/68] rm data compilation in ``PDCD_WS`` --- skglm/experimental/pdcd_ws.py | 6 +----- skglm/experimental/tests/test_quantile_regression.py | 6 +++++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index d0475c636..2c4ef224e 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -92,11 +92,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, self.tol = tol self.verbose = verbose - def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None): - # jit compile classes - datafit = compiled_clone(datafit_) - penalty = compiled_clone(penalty_) - + def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): n_samples, n_features = X.shape # init steps diff --git a/skglm/experimental/tests/test_quantile_regression.py b/skglm/experimental/tests/test_quantile_regression.py index 509b7079c..65e0c1e65 100644 --- a/skglm/experimental/tests/test_quantile_regression.py +++ b/skglm/experimental/tests/test_quantile_regression.py @@ -5,6 +5,7 @@ from skglm.penalties import L1 from skglm.experimental.pdcd_ws import PDCD_WS from skglm.experimental.quantile_regression import Pinball +from skglm.utils.jit_compilation import compiled_clone from skglm.utils.data import make_correlated_data from sklearn.linear_model import QuantileRegressor @@ -21,9 +22,12 @@ def test_PDCD_WS(quantile_level): alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf) alpha = alpha_max / 5 + datafit = compiled_clone(Pinball(quantile_level)) + penalty = compiled_clone(L1(alpha)) + w = PDCD_WS( dual_init=np.sign(y)/2 + (quantile_level - 0.5) - ).solve(X, y, Pinball(quantile_level), L1(alpha))[0] + ).solve(X, y, datafit, penalty)[0] clf = QuantileRegressor( quantile=quantile_level, From 200337060eac142b7b0032bbb67b208f56ad8746 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 12:32:31 +0200 Subject: [PATCH 46/68] rm unused imports --- skglm/experimental/pdcd_ws.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 2c4ef224e..43949b1a6 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -6,7 +6,6 @@ from numba import njit from skglm.solvers import BaseSolver -from skglm.utils.jit_compilation import compiled_clone from sklearn.exceptions import ConvergenceWarning From a2aa8f55ed83f99663d34c1a36a7f8de987e190a Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 24 May 2024 12:52:59 +0200 Subject: [PATCH 47/68] fix test ``PDCD_WS`` --- skglm/experimental/tests/test_sqrt_lasso.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skglm/experimental/tests/test_sqrt_lasso.py b/skglm/experimental/tests/test_sqrt_lasso.py index 91722abea..f5b044a86 100644 --- a/skglm/experimental/tests/test_sqrt_lasso.py +++ b/skglm/experimental/tests/test_sqrt_lasso.py @@ -7,6 +7,7 @@ from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic, _chambolle_pock_sqrt) from skglm.experimental.pdcd_ws import PDCD_WS +from skglm.utils.jit_compilation import compiled_clone def test_alpha_max(): @@ -69,7 +70,10 @@ def test_PDCD_WS(with_dual_init): dual_init = y / norm(y) if with_dual_init else None - w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0] + datafit = compiled_clone(SqrtQuadratic()) + penalty = compiled_clone(L1(alpha)) + + w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0] clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y) np.testing.assert_allclose(clf.coef_, w, atol=1e-6) From 93c8dc09847b77dbcd9537982eac6146b0a1d715 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 30 May 2024 16:54:07 +0200 Subject: [PATCH 48/68] error msg --- skglm/utils/validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 117fcedb4..f04913356 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -77,12 +77,12 @@ def check_obj_solver_attr(obj, solver, required_attr, support_sparse=False): solver_name = name_matcher.search(str(solver.__class__)).group(1) if not support_sparse: - err_message = f"{obj_name} is not compatible with {solver_name}." + err_message = f"{obj_name} is not compatible with solver {solver_name}." else: - err_message = (f"{obj_name} is not compatible with {solver_name}" + err_message = (f"{obj_name} is not compatible with solver {solver_name} " "with sparse data.") - err_message += (f" It must implement {' and '.join(required_attr)}\n" + err_message += (f" It must implement {' and '.join(required_attr)}.\n" f"Missing {' and '.join(missing_attrs)}.") raise AttributeError(err_message) From 9c291a983d2aafbf3d4bb175be25829faa99a3f0 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Thu, 30 May 2024 17:19:24 +0200 Subject: [PATCH 49/68] Update skglm/solvers/fista.py Co-authored-by: mathurinm --- skglm/solvers/fista.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index 33e2be46e..e4956f574 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -121,5 +121,5 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( "Penalty must implement `subdiff_distance` " - "to use self.opt_strategy='subdiff'." + "to use `opt_strategy='subdiff'` in Fista solver." ) From 65895ae65ec04f0b97a2c156270810bd71370618 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Thu, 30 May 2024 17:19:34 +0200 Subject: [PATCH 50/68] Update skglm/solvers/gram_cd.py Co-authored-by: mathurinm --- skglm/solvers/gram_cd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index e3e3e69b0..42babad81 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -140,7 +140,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def custom_compatibility_check(self, X, y, datafit): if not isinstance(datafit, Quadratic): raise AttributeError( - f"`GramCD` supports only `Quadratic` datafit, got {datafit}" + f"`GramCD` supports only `Quadratic` datafit, got {datafit}." ) From 2c3873d768b9b1ed55cc3510fd7b1ab3035c3548 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Thu, 30 May 2024 17:20:02 +0200 Subject: [PATCH 51/68] Update skglm/solvers/anderson_cd.py Co-authored-by: mathurinm --- skglm/solvers/anderson_cd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 464dd2fb6..af7d734f9 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -283,7 +283,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( "Penalty must implement `subdiff_distance` " - "to use self.ws_strategy='subdiff'." + "to use ws_strategy='subdiff' in solver AndersonCD." ) From ade9bad4a1ab1ce02132e84e7ccd12a89cc3a9c6 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 30 May 2024 17:21:50 +0200 Subject: [PATCH 52/68] number of arguments in GramCD custom_campatibility_check --- skglm/solvers/gram_cd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index e3e3e69b0..03fb71f92 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -137,7 +137,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.array(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit): + def custom_compatibility_check(self, X, y, datafit, penalty): if not isinstance(datafit, Quadratic): raise AttributeError( f"`GramCD` supports only `Quadratic` datafit, got {datafit}" From f3fae3eadcfa41c55eee9ad3f37d5deafa635394 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 30 May 2024 17:26:33 +0200 Subject: [PATCH 53/68] change version because this is a large change --- skglm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/__init__.py b/skglm/__init__.py index c134c98f2..d80de3c17 100644 --- a/skglm/__init__.py +++ b/skglm/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.2dev' +__version__ = '0.4dev' from skglm.estimators import ( # noqa F401 Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC, From c484aa42464310940eebd1a065f831b539bef188 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Sun, 14 Jul 2024 22:21:02 +0200 Subject: [PATCH 54/68] Update skglm/solvers/gram_cd.py Co-authored-by: mathurinm --- skglm/solvers/gram_cd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 242dbff60..d0d6cef47 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -140,7 +140,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def custom_compatibility_check(self, X, y, datafit, penalty): if not isinstance(datafit, Quadratic): raise AttributeError( - f"`GramCD` supports only `Quadratic` datafit, got {datafit}." + f"`GramCD` supports only `Quadratic` datafit, got {datafit.__class__.__name__}." ) From 3b1967270526937861e12d04ec60d511263ff124 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 22:26:12 +0200 Subject: [PATCH 55/68] more on remarks --- skglm/solvers/gram_cd.py | 3 ++- skglm/solvers/prox_newton.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index d0d6cef47..9abf15d19 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -140,7 +140,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def custom_compatibility_check(self, X, y, datafit, penalty): if not isinstance(datafit, Quadratic): raise AttributeError( - f"`GramCD` supports only `Quadratic` datafit, got {datafit.__class__.__name__}." + "`GramCD` supports only `Quadratic` datafit, " + f"got {datafit.__class__.__name__}." ) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 737390d73..3ac14d129 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -204,7 +204,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( "Penalty must implement `subdiff_distance` " - "to use ws_strategy='subdiff'." + "to use ws_strategy='subdiff' in ProxNewton solver" ) From 352aa3f3fd8519cd37d0a3677730d226d1b123ac Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 22:32:05 +0200 Subject: [PATCH 56/68] ``check_obj_solver_attr`` ---> ``check_attrs`` --- skglm/solvers/anderson_cd.py | 4 ++-- skglm/solvers/base.py | 6 +++--- skglm/solvers/fista.py | 4 ++-- skglm/solvers/lbfgs.py | 4 ++-- skglm/solvers/multitask_bcd.py | 4 ++-- skglm/utils/validation.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 67b244621..44170c8f3 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -7,7 +7,7 @@ ) from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_obj_solver_attr +from skglm.utils.validation import check_attrs class AndersonCD(BaseSolver): @@ -275,7 +275,7 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, def custom_compatibility_check(self, X, y, datafit, penalty): # check datafit support sparse data - check_obj_solver_attr( + check_attrs( datafit, solver=self, required_attr=self._datafit_required_attr, support_sparse=sparse.issparse(X) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 5994968bd..7919cb425 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABC -from skglm.utils.validation import check_obj_solver_attr +from skglm.utils.validation import check_attrs class BaseSolver(ABC): @@ -97,5 +97,5 @@ def _validate(self, X, y, datafit, penalty): # execute both: attributes checks and `custom_compatibility_check` self.custom_compatibility_check(X, y, datafit, penalty) - check_obj_solver_attr(datafit, self, self._datafit_required_attr) - check_obj_solver_attr(penalty, self, self._penalty_required_attr) + check_attrs(datafit, self, self._datafit_required_attr) + check_attrs(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index e4956f574..f75ed1794 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -3,7 +3,7 @@ from skglm.solvers.base import BaseSolver from skglm.solvers.common import construct_grad, construct_grad_sparse from skglm.utils.prox_funcs import _prox_vec -from skglm.utils.validation import check_obj_solver_attr +from skglm.utils.validation import check_attrs class FISTA(BaseSolver): @@ -111,7 +111,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def custom_compatibility_check(self, X, y, datafit, penalty): # check datafit support sparse data - check_obj_solver_attr( + check_attrs( datafit, solver=self, required_attr=self._datafit_required_attr, support_sparse=issparse(X) diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index df49e3cfe..527abbc7e 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -7,7 +7,7 @@ from scipy.sparse import issparse from skglm.solvers import BaseSolver -from skglm.utils.validation import check_obj_solver_attr +from skglm.utils.validation import check_attrs class LBFGS(BaseSolver): @@ -109,7 +109,7 @@ def callback_post_iter(w_k): def custom_compatibility_check(self, X, y, datafit, penalty): # check datafit support sparse data - check_obj_solver_attr( + check_attrs( datafit, solver=self, required_attr=self._datafit_required_attr, support_sparse=issparse(X) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index b25861ebe..93d1e3c25 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -4,7 +4,7 @@ from numpy.linalg import norm from sklearn.utils import check_array from skglm.solvers.base import BaseSolver -from skglm.utils.validation import check_obj_solver_attr +from skglm.utils.validation import check_attrs class MultiTaskBCD(BaseSolver): @@ -237,7 +237,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) def custom_compatibility_check(self, X, y, datafit, penalty): # check datafit support sparse data - check_obj_solver_attr( + check_attrs( datafit, solver=self, required_attr=self._datafit_required_attr, support_sparse=sparse.issparse(X) diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index f04913356..9b9b5af0c 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -30,7 +30,7 @@ def check_group_compatible(obj): ) -def check_obj_solver_attr(obj, solver, required_attr, support_sparse=False): +def check_attrs(obj, solver, required_attr, support_sparse=False): """Check whether datafit or penalty is compatible with solver. Parameters From 6c3ddbc5a128bda4a5513e67e32d75e8052c479c Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 22:42:12 +0200 Subject: [PATCH 57/68] implement ``_solve`` and ``solve`` --- skglm/estimators.py | 6 +++--- skglm/solvers/anderson_cd.py | 2 +- skglm/solvers/base.py | 15 +++++++++------ skglm/solvers/fista.py | 2 +- skglm/solvers/gram_cd.py | 2 +- skglm/solvers/group_bcd.py | 2 +- skglm/solvers/group_prox_newton.py | 2 +- skglm/solvers/lbfgs.py | 2 +- skglm/solvers/multitask_bcd.py | 2 +- skglm/solvers/prox_newton.py | 2 +- 10 files changed, 20 insertions(+), 17 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 5d7006452..cc488a422 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -136,7 +136,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): "The size of the WeightedL1 penalty weights should be n_features, " "expected %i, got %i." % (X_.shape[1], len(penalty.weights))) - coefs, p_obj, kkt = solver(X_, y, datafit_jit, penalty_jit, w, Xw) + coefs, p_obj, kkt = solver.solve(X_, y, datafit_jit, penalty_jit, w, Xw) model.coef_, model.stop_crit_ = coefs[:n_features], kkt if y.ndim == 1: model.intercept_ = coefs[-1] if fit_intercept else 0. @@ -1352,7 +1352,7 @@ def fit(self, X, y): else: datafit.initialize_sparse(X.data, X.indptr, X.indices, y) - w, _, stop_crit = solver(X, y, datafit, penalty) + w, _, stop_crit = solver.solve(X, y, datafit, penalty) # save to attribute self.coef_ = w @@ -1484,7 +1484,7 @@ def fit(self, X, Y): self.max_iter, self.max_epochs, self.p0, tol=self.tol, ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept, warm_start=self.warm_start, verbose=self.verbose) - W, obj_out, kkt = solver(X, Y, datafit_jit, penalty_jit) + W, obj_out, kkt = solver.solve(X, Y, datafit_jit, penalty_jit) self.coef_ = W[:X.shape[1], :].T self.intercept_ = self.fit_intercept * W[-1, :] diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 44170c8f3..fbf5b1029 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -63,7 +63,7 @@ def __init__(self, max_iter=50, max_epochs=50_000, p0=10, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy not in ("subdiff", "fixpoint"): raise ValueError( 'Unsupported value for self.ws_strategy:', self.ws_strategy) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 7919cb425..9f52feaa9 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -23,7 +23,7 @@ class BaseSolver(ABC): _penalty_required_attr: list @abstractmethod - def solve(self, X, y, datafit, penalty, w_init, Xw_init): + def _solve(self, X, y, datafit, penalty, w_init, Xw_init): """Solve an optimization problem. Parameters @@ -79,19 +79,22 @@ def custom_compatibility_check(self, X, y, datafit, penalty): Penalty. """ - def __call__(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None, + *, run_checks=True): """Solve the optimization problem after validating its compatibility. - A proxy of ``solve`` method that implicitly ensures the compatibility + A proxy of ``_solve`` method that implicitly ensures the compatibility of ``datafit`` and ``penalty`` with the solver. Examples -------- >>> ... - >>> coefs, obj_out, stop_crit = solver(X, y, datafit, penalty) + >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty) """ - self._validate(X, y, datafit, penalty) - return self.solve(X, y, datafit, penalty, w_init, Xw_init) + if run_checks: + self._validate(X, y, datafit, penalty) + + return self._solve(X, y, datafit, penalty, w_init, Xw_init) def _validate(self, X, y, datafit, penalty): # execute both: attributes checks and `custom_compatibility_check` diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index f75ed1794..a0e484753 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -39,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0): self.fit_intercept = False # needed to be passed to GeneralizedLinearEstimator self.warm_start = False - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out = [] n_samples, n_features = X.shape all_features = np.arange(n_features) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 9abf15d19..70e3a3029 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -64,7 +64,7 @@ def __init__(self, max_iter=100, use_acc=False, greedy_cd=True, tol=1e-4, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): # we don't pass Xw_init as the solver uses Gram updates # to keep the gradient up-to-date instead of Xw n_samples, n_features = X.shape diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 62e68b1a9..595ddd2b8 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -52,7 +52,7 @@ def __init__( self.ws_strategy = ws_strategy self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy not in ("subdiff", "fixpoint"): raise ValueError( 'Unsupported value for self.ws_strategy:', self.ws_strategy) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index e27540e23..80dca472a 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -57,7 +57,7 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): fit_intercept = self.fit_intercept n_samples, n_features = X.shape grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index 527abbc7e..01a35d499 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -36,7 +36,7 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False): self.tol = tol self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): def objective(w): Xw = X @ w diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 93d1e3c25..70b99940f 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -26,7 +26,7 @@ def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): + def _solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): n_samples, n_features = X.shape n_tasks = Y.shape[1] pen = penalty.is_penalized(n_features) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index 3ac14d129..a25cf88d4 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -67,7 +67,7 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy not in ("subdiff", "fixpoint"): raise ValueError("ws_strategy must be `subdiff` or `fixpoint`, " f"got {self.ws_strategy}.") From de1f9aea3f8cda07bdd5f8905f3aef8ec557c69d Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 22:51:39 +0200 Subject: [PATCH 58/68] forgotten `PDCD_WS` solver --- skglm/experimental/pdcd_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 43949b1a6..306b4947f 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -91,7 +91,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, self.tol = tol self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): n_samples, n_features = X.shape # init steps From 1c1b2588e490e70d2c78c89459dc7e2d80e93546 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 22:52:03 +0200 Subject: [PATCH 59/68] fix `GramCD` checks --- skglm/solvers/gram_cd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 70e3a3029..ccb5b8e57 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -138,10 +138,10 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.array(p_objs_out), stop_crit def custom_compatibility_check(self, X, y, datafit, penalty): - if not isinstance(datafit, Quadratic): + if datafit is not None: raise AttributeError( - "`GramCD` supports only `Quadratic` datafit, " - f"got {datafit.__class__.__name__}." + "`GramCD` supports only `Quadratic` datafit and fits it implicitly, " + f"argument `datafit` must be `None`, got {datafit.__class__.__name__}." ) From bf1994483b556647ab683a974b6d093b0846856e Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 23:16:40 +0200 Subject: [PATCH 60/68] linter happy & fix validation --- skglm/solvers/gram_cd.py | 1 - skglm/tests/test_validation.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index ccb5b8e57..0b50afa1f 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -3,7 +3,6 @@ from numba import njit from scipy.sparse import issparse -from skglm.datafits import Quadratic from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index 829458e34..25c23c9a7 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -1,9 +1,10 @@ import pytest -from skglm.datafits import Quadratic, Poisson + from skglm.penalties import L1 +from skglm.datafits import Poisson, Huber from skglm.solvers import FISTA, ProxNewton -from skglm.utils.jit_compilation import compiled_clone from skglm.utils.data import make_correlated_data +from skglm.utils.jit_compilation import compiled_clone def test_datafit_penalty_solver_compatibility(): @@ -13,7 +14,7 @@ def test_datafit_penalty_solver_compatibility(): AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): ProxNewton()._validate( - X, y, compiled_clone(Quadratic()), compiled_clone(L1(1.)) + X, y, compiled_clone(Huber(1.)), compiled_clone(L1(1.)) ) with pytest.raises( From 290879b5d16224fa09baaa3f7c65aee61c8dc200 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 14 Jul 2024 23:52:34 +0200 Subject: [PATCH 61/68] cleanups and comments --- skglm/solvers/base.py | 5 ++++- skglm/solvers/group_bcd.py | 2 -- skglm/utils/validation.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 9f52feaa9..9d488e18e 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -78,6 +78,7 @@ def custom_compatibility_check(self, X, y, datafit, penalty): penalty : instance of BasePenalty Penalty. """ + pass def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True): @@ -97,8 +98,10 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None, return self._solve(X, y, datafit, penalty, w_init, Xw_init) def _validate(self, X, y, datafit, penalty): - # execute both: attributes checks and `custom_compatibility_check` + # execute: `custom_compatibility_check` then check attributes self.custom_compatibility_check(X, y, datafit, penalty) + # do not check for sparse support here, make the check at the solver level + # some solvers like ProxNewton don't require methods for sparse support check_attrs(datafit, self, self._datafit_required_attr) check_attrs(penalty, self, self._penalty_required_attr) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 595ddd2b8..bc0bf2311 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -56,8 +56,6 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy not in ("subdiff", "fixpoint"): raise ValueError( 'Unsupported value for self.ws_strategy:', self.ws_strategy) - check_group_compatible(datafit) - check_group_compatible(penalty) n_samples, n_features = X.shape n_groups = len(penalty.grp_ptr) - 1 diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 9b9b5af0c..14cd8a9ce 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -56,7 +56,7 @@ def check_attrs(obj, solver, required_attr, support_sparse=False): missing_attrs = [] suffix = SPARSE_SUFFIX if support_sparse else "" - # if `attr` is a list check that at least one of them + # if `attr` is a list, check that at least one of them # is within `obj` attributes for attr in required_attr: attributes = attr if not isinstance(attr, str) else (attr,) From 4f5781eccc75210eddbad90f8b20fe63650899f3 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 10:32:40 +0200 Subject: [PATCH 62/68] more on docs --- skglm/solvers/base.py | 9 +++++++++ skglm/utils/validation.py | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index 9d488e18e..b64011927 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -17,6 +17,15 @@ class BaseSolver(ABC): ----- For required attributes, if an attribute is given as a list of attributes it means at least one of them should be implemented. + For instance, if + + _datafit_required_attr = ( + "get_global_lipschitz", + ("gradient", "gradient_scalar") + ) + + it mean datafit must implement the methods ``get_global_lipschitz`` + and (``gradient`` or ``gradient_scaler``). """ _datafit_required_attr: list diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py index 14cd8a9ce..264ad4bb7 100644 --- a/skglm/utils/validation.py +++ b/skglm/utils/validation.py @@ -49,9 +49,9 @@ def check_attrs(obj, solver, required_attr, support_sparse=False): Raises ------ - AttributeError - if any of the attribute in ``required_attr`` is missing - from ``obj`` attributes. + AttributeError + if any of the attribute in ``required_attr`` is missing + from ``obj`` attributes. """ missing_attrs = [] suffix = SPARSE_SUFFIX if support_sparse else "" From 83f10d5dae131cac6229a46e33deda401b1ece29 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 10:35:45 +0200 Subject: [PATCH 63/68] ``custom_compatibility_check`` ---> ``custom_checks`` --- skglm/experimental/pdcd_ws.py | 2 +- skglm/solvers/anderson_cd.py | 2 +- skglm/solvers/base.py | 6 +++--- skglm/solvers/fista.py | 2 +- skglm/solvers/gram_cd.py | 2 +- skglm/solvers/group_bcd.py | 2 +- skglm/solvers/group_prox_newton.py | 2 +- skglm/solvers/lbfgs.py | 2 +- skglm/solvers/multitask_bcd.py | 2 +- skglm/solvers/prox_newton.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 306b4947f..81e72da8c 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -196,7 +196,7 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, if stop_crit_in <= tol_in: break - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): if issparse(X): raise ValueError( "Sparse matrices are not yet supported in `PDCD_WS` solver." diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index fbf5b1029..d39a24086 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -273,7 +273,7 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None, results += (n_iters,) return results - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): # check datafit support sparse data check_attrs( datafit, solver=self, diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py index b64011927..06a08a690 100644 --- a/skglm/solvers/base.py +++ b/skglm/solvers/base.py @@ -67,7 +67,7 @@ def _solve(self, X, y, datafit, penalty, w_init, Xw_init): Value of stopping criterion at convergence. """ - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): """Ensure the solver is suited for the `datafit` + `penalty` problem. This method includes extra checks to perform @@ -107,8 +107,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None, return self._solve(X, y, datafit, penalty, w_init, Xw_init) def _validate(self, X, y, datafit, penalty): - # execute: `custom_compatibility_check` then check attributes - self.custom_compatibility_check(X, y, datafit, penalty) + # execute: `custom_checks` then check attributes + self.custom_checks(X, y, datafit, penalty) # do not check for sparse support here, make the check at the solver level # some solvers like ProxNewton don't require methods for sparse support diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py index a0e484753..e0933a111 100644 --- a/skglm/solvers/fista.py +++ b/skglm/solvers/fista.py @@ -109,7 +109,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): break return w, np.array(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): # check datafit support sparse data check_attrs( datafit, solver=self, diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 0b50afa1f..9ecf42bfb 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -136,7 +136,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.array(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): if datafit is not None: raise AttributeError( "`GramCD` supports only `Quadratic` datafit and fits it implicitly, " diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index bc0bf2311..07e90a812 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -182,7 +182,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 80dca472a..f526eaa48 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -145,7 +145,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py index 01a35d499..438c8b97b 100644 --- a/skglm/solvers/lbfgs.py +++ b/skglm/solvers/lbfgs.py @@ -107,7 +107,7 @@ def callback_post_iter(w_k): return w, np.asarray(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): # check datafit support sparse data check_attrs( datafit, solver=self, diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 70b99940f..5a8dfa5e6 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -235,7 +235,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) return results - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): # check datafit support sparse data check_attrs( datafit, solver=self, diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index a25cf88d4..76867c7d8 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -199,7 +199,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): ) return w, np.asarray(p_objs_out), stop_crit - def custom_compatibility_check(self, X, y, datafit, penalty): + def custom_checks(self, X, y, datafit, penalty): # ws strategy if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): raise AttributeError( From e2bcdcc079da7e4c6aab4107d79d0e9dddd378d6 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 11:45:26 +0200 Subject: [PATCH 64/68] rm unimplemented methods --- skglm/datafits/single_task.py | 3 --- skglm/penalties/block_separable.py | 11 ----------- 2 files changed, 14 deletions(-) diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py index 5750ea295..1ccb218aa 100644 --- a/skglm/datafits/single_task.py +++ b/skglm/datafits/single_task.py @@ -661,9 +661,6 @@ def value(self, y, w, Xw): def gradient_scalar(self, X, y, w, Xw, j): return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y) - def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): - pass - def intercept_update_step(self, y, Xw): return np.sum(self.raw_grad(y, Xw)) diff --git a/skglm/penalties/block_separable.py b/skglm/penalties/block_separable.py index 47161080e..091392601 100644 --- a/skglm/penalties/block_separable.py +++ b/skglm/penalties/block_separable.py @@ -458,17 +458,6 @@ def prox_1group(self, value, stepsize, g): res = ST_vec(value, self.alpha * stepsize * self.weights_features[g]) return BST(res, self.alpha * stepsize * self.weights_groups[g]) - def subdiff_distance(self, w, grad_ws, ws): - """Compute distance to the subdifferential at ``w`` of negative gradient. - - Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation. - - Note: - ---- - ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``. - """ - raise NotImplementedError("Too hard for now") - def is_penalized(self, n_groups): return np.ones(n_groups, dtype=np.bool_) From 3448e52dd71559d029bddfb67363c0cd55eecc4c Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 11:45:58 +0200 Subject: [PATCH 65/68] correct error message --- skglm/solvers/group_bcd.py | 16 +++++++++++++++- skglm/solvers/group_prox_newton.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 07e90a812..259f61ff0 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -4,7 +4,7 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.validation import check_group_compatible +from skglm.utils.validation import check_group_compatible, check_attrs from skglm.solvers.common import dist_fix_point_bcd @@ -186,6 +186,20 @@ def custom_checks(self, X, y, datafit, penalty): check_group_compatible(datafit) check_group_compatible(penalty) + # check datafit support sparse data + check_attrs( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=issparse(X) + ) + + # ws strategy + if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use ws_strategy='subdiff'." + ) + @njit def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws): diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index f526eaa48..1492651c3 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -151,7 +151,7 @@ def custom_checks(self, X, y, datafit, penalty): if issparse(X): raise ValueError( - "Sparse matrices are not yet supported in `GroupBCD` solver." + "Sparse matrices are not yet supported in `GroupProxNewton` solver." ) From b7b46277d90cfd3d65a2587749e68fa3d42d923e Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 11:46:22 +0200 Subject: [PATCH 66/68] more on validation unit tests --- skglm/tests/test_validation.py | 66 ++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index 25c23c9a7..7e998bfb8 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -1,35 +1,89 @@ import pytest +import numpy as np +from scipy import sparse -from skglm.penalties import L1 -from skglm.datafits import Poisson, Huber -from skglm.solvers import FISTA, ProxNewton +from skglm.penalties import L1, WeightedL1GroupL2, WeightedGroupL2 +from skglm.datafits import Poisson, Huber, QuadraticGroup, LogisticGroup +from skglm.solvers import FISTA, ProxNewton, GroupBCD, GramCD, GroupProxNewton + +from skglm.utils.data import grp_converter from skglm.utils.data import make_correlated_data from skglm.utils.jit_compilation import compiled_clone def test_datafit_penalty_solver_compatibility(): - X, y, _ = make_correlated_data(n_samples=3, n_features=5) + grp_size, n_features = 3, 9 + n_samples = 10 + X, y, _ = make_correlated_data(n_samples, n_features) + X_sparse = sparse.csc_array(X) + + n_groups = n_features // grp_size + weights_groups = np.ones(n_groups) + weights_features = np.ones(n_features) + grp_indices, grp_ptr = grp_converter(grp_size, n_features) + # basic compatibility checks with pytest.raises( AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): ProxNewton()._validate( X, y, compiled_clone(Huber(1.)), compiled_clone(L1(1.)) ) - with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) - with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) + # check Gram Solver + with pytest.raises( + AttributeError, match="`GramCD` supports only `Quadratic` datafit" + ): + GramCD()._validate( + X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + ) + # check working set strategy subdiff + with pytest.raises( + AttributeError, match="Penalty must implement `subdiff_distance`" + ): + GroupBCD()._validate( + X, y, + datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedL1GroupL2( + 1., weights_groups, weights_features, grp_ptr, grp_indices) + ) + ) + # checks for sparsity + with pytest.raises( + ValueError, + match="Sparse matrices are not yet supported in `GroupProxNewton` solver." + ): + GroupProxNewton()._validate( + X_sparse, y, + datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedL1GroupL2( + 1., weights_groups, weights_features, grp_ptr, grp_indices) + ) + ) + with pytest.raises( + AttributeError, + match="LogisticGroup is not compatible with solver GroupBCD with sparse data." + ): + GroupBCD()._validate( + X_sparse, y, + datafit=compiled_clone(LogisticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedGroupL2(1., weights_groups, grp_ptr, grp_indices) + ) + ) if __name__ == "__main__": From d626e367bf907a199f27963b329a455667e97c76 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 11:50:30 +0200 Subject: [PATCH 67/68] update what's new --- doc/changes/0.4.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/changes/0.4.rst b/doc/changes/0.4.rst index 05d924c3e..904401fd9 100644 --- a/doc/changes/0.4.rst +++ b/doc/changes/0.4.rst @@ -15,4 +15,3 @@ Version 0.3.1 (2023/12/21) - Add :ref:`LogSumPenalty ` (PR: :gh:`#127`) - Remove abstract methods in ``BaseDatafit`` and ``BasePenalty`` to make solver/penalty/datafit compatibility check easier (PR :gh:`#205`) - Add fixed-point distance to build working sets in :ref:`ProxNewton ` solver (:gh:`138`) -- Check compatibility between ``datafit + penalty`` and solver (PR :gh:`137`) From a02262b4e162c4a85ccd45529f777ebf4c4a9135 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 15 Jul 2024 12:00:24 +0200 Subject: [PATCH 68/68] handle `subdiff_distance` in custom checks --- skglm/solvers/group_bcd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 259f61ff0..c7b515dad 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -38,7 +38,7 @@ class GroupBCD(BaseSolver): """ _datafit_required_attr = ("get_lipschitz", "gradient_g") - _penalty_required_attr = ("prox_1group", "subdiff_distance") + _penalty_required_attr = ("prox_1group",) def __init__( self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, fit_intercept=False,