ENH - check datafit + penalty compatibility with solver #137

Merged: 77 commits, Jul 15, 2024

Commits (77)
0154215
initial commit
PABannier Dec 10, 2022
3c70b85
delegated to solve
PABannier Jan 7, 2023
ee0d29c
call `solver.validate`
PABannier Jan 7, 2023
8831cd2
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Oct 18, 2023
9abd149
add validate method to solvers
Badr-MOUFAD Oct 18, 2023
d20229a
implem validation logic
Badr-MOUFAD Oct 18, 2023
310c572
implem attribute validation for solvers
Badr-MOUFAD Oct 18, 2023
1872e4e
validation ``PDCD_WS``
Badr-MOUFAD Oct 18, 2023
cd1ba1c
fix trailing spaces
Badr-MOUFAD Oct 18, 2023
4c9b4d4
add docs to ``check_obj_solver_compatibility``
Badr-MOUFAD Oct 18, 2023
c3d01c4
add validation glm_fit
Badr-MOUFAD Oct 18, 2023
eacea14
fix Error logs
Badr-MOUFAD Oct 18, 2023
573cb78
fix prox solvers attribute names
Badr-MOUFAD Oct 18, 2023
8dedf18
add ``initialize`` to required attributes
Badr-MOUFAD Oct 18, 2023
19fbe63
add change to what's new
Badr-MOUFAD Oct 18, 2023
3b65445
formatting & Fista validation
Badr-MOUFAD Oct 18, 2023
ef64f82
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Oct 26, 2023
ddc67b6
get names obj and solver in ``check_obj_solver_attr_compatibility``
Badr-MOUFAD Oct 26, 2023
5d75070
follow up changes
Badr-MOUFAD Oct 26, 2023
5d1b618
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Nov 2, 2023
ce118ca
specify required attributes in solvers
Badr-MOUFAD Nov 2, 2023
f38d12e
add ``*_required_attr`` in ``BaseSolver``
Badr-MOUFAD Nov 2, 2023
d1aeb4c
pass in solver to ``check_obj_solver_attr_compatibility``
Badr-MOUFAD Nov 2, 2023
b4e2d8e
handle solvers that supports ``ws_strategy='subdiff_distance'``
Badr-MOUFAD Nov 2, 2023
92f5d3b
revert solver with subdiff check
Badr-MOUFAD Nov 2, 2023
af7ae2c
unittest validation && abc fixes
Badr-MOUFAD Nov 2, 2023
bbfa438
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Nov 16, 2023
850f04f
implicit validation in ``__call__``
Badr-MOUFAD Nov 16, 2023
84a2cf8
validation logic revisited
Badr-MOUFAD Nov 16, 2023
05e5cff
BaseSolver as abstract class
Badr-MOUFAD Nov 16, 2023
4aefc8c
add required attributes
Badr-MOUFAD Nov 16, 2023
626fe77
pass & cleanups
Badr-MOUFAD Nov 16, 2023
b50ae91
unittest && docs
Badr-MOUFAD Nov 16, 2023
e81409f
validations for sparse data support
Badr-MOUFAD Nov 16, 2023
39fa7b9
validate ``subdiff`` and ``fixpoint``
Badr-MOUFAD Nov 16, 2023
40117a4
sparse support in group solvers
Badr-MOUFAD Nov 16, 2023
bab6277
more on unittest
Badr-MOUFAD Nov 16, 2023
7b79539
fix what's new
Badr-MOUFAD Nov 16, 2023
118a0da
use ``__call__`` instead of ``solve``
Badr-MOUFAD Nov 16, 2023
94f47c2
fix ``BaseSolver``
Badr-MOUFAD Nov 16, 2023
419df06
Update skglm/solvers/group_bcd.py
Badr-MOUFAD Nov 24, 2023
a26387b
Update skglm/solvers/group_prox_newton.py
Badr-MOUFAD Nov 24, 2023
677d0e3
Update skglm/experimental/pdcd_ws.py
Badr-MOUFAD Nov 24, 2023
5fb3994
Merge branch 'main' of github.com:scikit-learn-contrib/skglm into sol…
mathurinm Apr 11, 2024
8f840bf
chenges.rst
mathurinm Apr 11, 2024
d8cc022
sparse matrices are now supported by GroupBCD
mathurinm Apr 11, 2024
8e41a5d
Merge branch 'main' into solver_dispatcher
Badr-MOUFAD May 24, 2024
fa9ed35
typo ``BaseSolver``
Badr-MOUFAD May 24, 2024
e687bc2
rm ``self`` in docs
Badr-MOUFAD May 24, 2024
fd67df8
more code-readable attribute error
Badr-MOUFAD May 24, 2024
78295d8
rm data compilation in ``PDCD_WS``
Badr-MOUFAD May 24, 2024
2003370
rm unused imports
Badr-MOUFAD May 24, 2024
a2aa8f5
fix test ``PDCD_WS``
Badr-MOUFAD May 24, 2024
93c8dc0
error msg
mathurinm May 30, 2024
9c291a9
Update skglm/solvers/fista.py
Badr-MOUFAD May 30, 2024
65895ae
Update skglm/solvers/gram_cd.py
Badr-MOUFAD May 30, 2024
2c3873d
Update skglm/solvers/anderson_cd.py
Badr-MOUFAD May 30, 2024
ade9bad
number of arguments in GramCD custom_campatibility_check
mathurinm May 30, 2024
49661d7
Merge branch 'solver_dispatcher' of github.com:PABannier/skglm into s…
mathurinm May 30, 2024
f3fae3e
change version because this is a large change
mathurinm May 30, 2024
ccd7cd7
Merge branch 'main' into solver_dispatcher
Badr-MOUFAD Jul 14, 2024
bb15aea
Merge remote-tracking branch 'PAB/solver_dispatcher' into solver_disp…
Badr-MOUFAD Jul 14, 2024
c484aa4
Update skglm/solvers/gram_cd.py
Badr-MOUFAD Jul 14, 2024
3b19672
more on remarks
Badr-MOUFAD Jul 14, 2024
352aa3f
``check_obj_solver_attr`` ---> ``check_attrs``
Badr-MOUFAD Jul 14, 2024
6c3ddbc
implement ``_solve`` and ``solve``
Badr-MOUFAD Jul 14, 2024
de1f9ae
forgotten `PDCD_WS` solver
Badr-MOUFAD Jul 14, 2024
1c1b258
fix `GramCD` checks
Badr-MOUFAD Jul 14, 2024
bf19944
linter happy & fix validation
Badr-MOUFAD Jul 14, 2024
290879b
cleanups and comments
Badr-MOUFAD Jul 14, 2024
4f5781e
more on docs
Badr-MOUFAD Jul 15, 2024
83f10d5
``custom_compatibility_check`` ---> ``custom_checks``
Badr-MOUFAD Jul 15, 2024
e2bcdcc
rm unimplemented methods
Badr-MOUFAD Jul 15, 2024
3448e52
correct error message
Badr-MOUFAD Jul 15, 2024
b7b4627
more on validation unit tests
Badr-MOUFAD Jul 15, 2024
d626e36
update what's new
Badr-MOUFAD Jul 15, 2024
a02262b
handle `subdiff_distance` in custom checks
Badr-MOUFAD Jul 15, 2024
1 change: 1 addition & 0 deletions doc/changes/0.4.rst
@@ -4,6 +4,7 @@ Version 0.4 (in progress)
-------------------------
- Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
- Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
- Check compatibility with datafit and penalty in solver (PR :gh:`137`)
- Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)


2 changes: 1 addition & 1 deletion skglm/__init__.py
@@ -1,4 +1,4 @@
__version__ = '0.3.2dev'
__version__ = '0.4dev'

from skglm.estimators import ( # noqa F401
Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
3 changes: 0 additions & 3 deletions skglm/datafits/single_task.py
@@ -661,9 +661,6 @@ def value(self, y, w, Xw):
def gradient_scalar(self, X, y, w, Xw, j):
return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y)

def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
pass

def intercept_update_step(self, y, Xw):
return np.sum(self.raw_grad(y, Xw))

37 changes: 11 additions & 26 deletions skglm/experimental/pdcd_ws.py
@@ -5,11 +5,12 @@
from scipy.sparse import issparse

from numba import njit
from skglm.utils.jit_compilation import compiled_clone
from skglm.solvers import BaseSolver

from sklearn.exceptions import ConvergenceWarning


class PDCD_WS:
class PDCD_WS(BaseSolver):
r"""Primal-Dual Coordinate Descent solver with working sets.

It solves
@@ -78,6 +79,9 @@ class PDCD_WS:
https://arxiv.org/abs/2204.07826
"""

_datafit_required_attr = ('prox_conjugate',)
_penalty_required_attr = ("prox_1d",)

def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
p0=100, tol=1e-6, verbose=False):
self.max_iter = max_iter
@@ -87,11 +91,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
self.tol = tol
self.verbose = verbose

def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
if issparse(X):
raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.")

datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
n_samples, n_features = X.shape

# init steps
@@ -196,27 +196,12 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
if stop_crit_in <= tol_in:
break

@staticmethod
def _validate_init(datafit_, penalty_):
# validate datafit
missing_attrs = []
for attr in ('prox_conjugate', 'subdiff_distance'):
if not hasattr(datafit_, attr):
missing_attrs.append(f"`{attr}`")

if len(missing_attrs):
raise AttributeError(
"Datafit is not compatible with PDCD_WS solver.\n"
"Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
f"Missing {' and '.join(missing_attrs)}."
def custom_checks(self, X, y, datafit, penalty):
if issparse(X):
raise ValueError(
"Sparse matrices are not yet supported in `PDCD_WS` solver."
)

# jit compile classes
compiled_datafit = compiled_clone(datafit_)
compiled_penalty = compiled_clone(penalty_)

return compiled_datafit, compiled_penalty


@njit
def _scores_primal(X, w, z, penalty, primal_steps, ws):
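As a quick illustration of the refactor above (the sparse-input guard now lives in custom_checks, which runs before _solve), here is a minimal sketch; the data and the alpha value are made up for illustration:

    import numpy as np
    from scipy import sparse

    from skglm.penalties import L1
    from skglm.experimental.pdcd_ws import PDCD_WS
    from skglm.experimental.quantile_regression import Pinball
    from skglm.utils.jit_compilation import compiled_clone

    X = sparse.random(20, 5, density=0.3, format="csc", random_state=0)
    y = np.random.randn(20)

    datafit = compiled_clone(Pinball(0.5))
    penalty = compiled_clone(L1(alpha=0.1))

    try:
        # solve() runs custom_checks, then the attribute checks, then _solve
        PDCD_WS().solve(X, y, datafit, penalty)
    except ValueError as err:
        print(err)  # Sparse matrices are not yet supported in `PDCD_WS` solver.
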
6 changes: 5 additions & 1 deletion skglm/experimental/tests/test_quantile_regression.py
@@ -5,6 +5,7 @@
from skglm.penalties import L1
from skglm.experimental.pdcd_ws import PDCD_WS
from skglm.experimental.quantile_regression import Pinball
from skglm.utils.jit_compilation import compiled_clone

from skglm.utils.data import make_correlated_data
from sklearn.linear_model import QuantileRegressor
@@ -21,9 +22,12 @@ def test_PDCD_WS(quantile_level):
alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
alpha = alpha_max / 5

datafit = compiled_clone(Pinball(quantile_level))
penalty = compiled_clone(L1(alpha))

w = PDCD_WS(
dual_init=np.sign(y)/2 + (quantile_level - 0.5)
).solve(X, y, Pinball(quantile_level), L1(alpha))[0]
).solve(X, y, datafit, penalty)[0]

clf = QuantileRegressor(
quantile=quantile_level,
6 changes: 5 additions & 1 deletion skglm/experimental/tests/test_sqrt_lasso.py
@@ -7,6 +7,7 @@
from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
_chambolle_pock_sqrt)
from skglm.experimental.pdcd_ws import PDCD_WS
from skglm.utils.jit_compilation import compiled_clone


def test_alpha_max():
@@ -69,7 +70,10 @@ def test_PDCD_WS(with_dual_init):

dual_init = y / norm(y) if with_dual_init else None

w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
datafit = compiled_clone(SqrtQuadratic())
penalty = compiled_clone(L1(alpha))

w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
Collaborator: How was this working before, if the datafit and penalties were not compiled?

Collaborator (reply): They were compiled inside the solver.

clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)

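As the exchange above clarifies, compilation has moved out of the solver: callers now compile the datafit and penalty once, at the call site. A minimal sketch of the new calling pattern, with illustrative random data and regularization strength:

    import numpy as np
    from numpy.linalg import norm

    from skglm.penalties import L1
    from skglm.experimental.pdcd_ws import PDCD_WS
    from skglm.experimental.sqrt_lasso import SqrtQuadratic
    from skglm.utils.jit_compilation import compiled_clone

    rng = np.random.default_rng(0)
    X, y = rng.standard_normal((50, 20)), rng.standard_normal(50)

    # compile once at the call site; solve() no longer compiles internally
    datafit = compiled_clone(SqrtQuadratic())
    penalty = compiled_clone(L1(alpha=0.1))

    w = PDCD_WS(dual_init=y / norm(y)).solve(X, y, datafit, penalty)[0]
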
11 changes: 0 additions & 11 deletions skglm/penalties/block_separable.py
@@ -458,17 +458,6 @@ def prox_1group(self, value, stepsize, g):
res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
return BST(res, self.alpha * stepsize * self.weights_groups[g])

def subdiff_distance(self, w, grad_ws, ws):
"""Compute distance to the subdifferential at ``w`` of negative gradient.

Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.

Note:
----
``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
"""
raise NotImplementedError("Too hard for now")

def is_penalized(self, n_groups):
return np.ones(n_groups, dtype=np.bool_)

10 changes: 0 additions & 10 deletions skglm/penalties/non_separable.py
@@ -48,13 +48,3 @@ def prox_vec(self, x, stepsize):
prox[sorted_indices] = prox_SLOPE(abs_x[sorted_indices], alphas * stepsize)

return np.sign(x) * prox

def prox_1d(self, value, stepsize, j):
raise ValueError(
"No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead."
)

def subdiff_distance(self, w, grad, ws):
return ValueError(
"No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`"
)
21 changes: 20 additions & 1 deletion skglm/solvers/anderson_cd.py
@@ -7,6 +7,7 @@
)
from skglm.solvers.base import BaseSolver
from skglm.utils.anderson import AndersonAcceleration
from skglm.utils.validation import check_attrs


class AndersonCD(BaseSolver):
@@ -47,6 +48,9 @@ class AndersonCD(BaseSolver):
code: https://github.com/mathurinm/andersoncd
"""

_datafit_required_attr = ("get_lipschitz", "gradient_scalar")
_penalty_required_attr = ("prox_1d",)

def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
tol=1e-4, ws_strategy="subdiff", fit_intercept=True,
warm_start=False, verbose=0):
@@ -59,7 +63,7 @@ def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
self.warm_start = warm_start
self.verbose = verbose

def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
if self.ws_strategy not in ("subdiff", "fixpoint"):
raise ValueError(
'Unsupported value for self.ws_strategy:', self.ws_strategy)
@@ -269,6 +273,21 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None,
results += (n_iters,)
return results

def custom_checks(self, X, y, datafit, penalty):
# check datafit support sparse data
check_attrs(
datafit, solver=self,
required_attr=self._datafit_required_attr,
support_sparse=sparse.issparse(X)
)

# ws strategy
if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
raise AttributeError(
"Penalty must implement `subdiff_distance` "
"to use ws_strategy='subdiff' in solver AndersonCD."
)


@njit
def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws):
83 changes: 79 additions & 4 deletions skglm/solvers/base.py
@@ -1,11 +1,38 @@
from abc import abstractmethod
from abc import abstractmethod, ABC
from skglm.utils.validation import check_attrs


class BaseSolver():
"""Base class for solvers."""
class BaseSolver(ABC):
"""Base class for solvers.

Attributes
----------
_datafit_required_attr : list
List of attributes that must be implemented in Datafit.
Collaborator: I think we can make these public @BadrMOUFAD? Maybe I'm missing a specific reason.

Collaborator: Also a typo: missing "that must BE" here and below.

Collaborator (reply): I think these two attributes should be read-only. While there is a way to make attributes read-only, namely using the property decorator, I believe it adds too much complexity to the code and hence doesn't serve our goal of making component implementations user-friendly. I opted for the "start with underscore" naming convention to signal to users that these attributes are not to be messed with.

Collaborator: Ok, this makes sense too!


_penalty_required_attr : list
List of attributes that must be implemented in Penalty.

Notes
-----
For required attributes, if an attribute is given as a list of attributes
it means at least one of them should be implemented.
Collaborator: I had trouble understanding this; an example would help (in which case do we want to check that one of several attributes is present?).

Collaborator (reply): I 100% agree with you @mathurinm, I should have accompanied the docs with an example. In the Fista solver, this

    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
    _penalty_required_attr = (("prox_1d", "prox_vec"),)

is interpreted as:

  • the datafit is required to have get_global_lipschitz and (gradient or gradient_scalar)
  • the penalty is required to have prox_1d or prox_vec

This is the way I implemented the check_obj_solver_attr function: whenever attributes are wrapped in parentheses, it is interpreted as the "or" operator, and a comma is interpreted as the "and" operator.

Collaborator: Can you be a bit more explicit, like: "if an element of the required-attributes list is itself a list, it means that at least one of these attributes must be implemented, e.g. _datafit_required_attr = ("a", ("b", "c")) means that the datafit must implement a, together with at least one of b and c"?

Collaborator (reply): Yes indeed.

For instance, if

_datafit_required_attr = (
"get_global_lipschitz",
("gradient", "gradient_scalar")
)

it means the datafit must implement the methods ``get_global_lipschitz``
and (``gradient`` or ``gradient_scalar``).
"""

_datafit_required_attr: list
_penalty_required_attr: list
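
To pin down the convention discussed above (top-level entries are AND-ed, entries wrapped in a tuple are OR-ed), here is an illustrative helper; it is a sketch of the semantics only, not the actual check_attrs implementation from skglm.utils.validation:

    def missing_attrs(obj, required_attrs):
        """List the entries of required_attrs that obj does not satisfy.

        A string entry must be implemented; a tuple entry is satisfied as
        soon as one of its names is implemented.
        """
        missing = []
        for attr in required_attrs:
            names = (attr,) if isinstance(attr, str) else attr
            if not any(hasattr(obj, name) for name in names):
                missing.append(" or ".join(names))
        return missing

    # FISTA's requirement: get_global_lipschitz and (gradient or gradient_scalar)
    required = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
    print(missing_attrs(object(), required))
    # ['get_global_lipschitz', 'gradient or gradient_scalar']
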

@abstractmethod
def solve(self, X, y, datafit, penalty, w_init, Xw_init):
def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
"""Solve an optimization problem.

Parameters
@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
stop_crit : float
Value of stopping criterion at convergence.
"""

def custom_checks(self, X, y, datafit, penalty):
"""Ensure the solver is suited for the `datafit` + `penalty` problem.

This method includes extra checks to perform
aside from checking attributes compatibility.

Parameters
----------
X : array, shape (n_samples, n_features)
Training data.

y : array, shape (n_samples,)
Target values.
QB3 (Nov 23, 2023): custom_compatibility_check currently depends on the target y, but the target is never used in the check. Should we remove y from this function, or do you see cases where it will be needed?

Collaborator (reply): Since the validation of the datafit/penalty depends on the data (for instance, when X is sparse we should check that the datafit implements the _sparse methods), IMO it is better to pass in both X and y. For now, we can settle for X only, but that means adding y later if we need it, which would alter the API.

Collaborator: I agree with @Badr-MOUFAD here; even if it is not needed at the moment, it's not too hard to see cases where this would happen, and an API change will be painful.

datafit : instance of BaseDatafit
Datafit.

penalty : instance of BasePenalty
Penalty.
"""
pass

def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
*, run_checks=True):
"""Solve the optimization problem after validating its compatibility.

A proxy of ``_solve`` method that implicitly ensures the compatibility
of ``datafit`` and ``penalty`` with the solver.

Examples
--------
>>> ...
>>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
"""
if run_checks:
self._validate(X, y, datafit, penalty)

return self._solve(X, y, datafit, penalty, w_init, Xw_init)

def _validate(self, X, y, datafit, penalty):
# execute: `custom_checks` then check attributes
self.custom_checks(X, y, datafit, penalty)

# do not check for sparse support here, make the check at the solver level
# some solvers like ProxNewton don't require methods for sparse support
check_attrs(datafit, self, self._datafit_required_attr)
check_attrs(penalty, self, self._penalty_required_attr)
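
Putting the flow together: solve() runs custom_checks, then the attribute checks, then _solve. A sketch of how an incompatible pairing now fails fast; AndersonCD + SLOPE is used here as an illustrative mismatch, since this PR removes SLOPE's prox_1d and subdiff_distance stubs:

    import numpy as np

    from skglm.datafits import Quadratic
    from skglm.penalties import SLOPE
    from skglm.solvers import AndersonCD
    from skglm.utils.jit_compilation import compiled_clone

    X, y = np.random.randn(30, 10), np.random.randn(30)
    datafit = compiled_clone(Quadratic())
    penalty = compiled_clone(SLOPE(alphas=0.1 * np.ones(10)))

    try:
        AndersonCD().solve(X, y, datafit, penalty)
    except AttributeError as err:
        print(err)  # the incompatibility is reported before any iteration runs

    # checks can be skipped explicitly when the pairing is known to be valid:
    # AndersonCD().solve(X, y, datafit, penalty, run_checks=False)
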
40 changes: 26 additions & 14 deletions skglm/solvers/fista.py
@@ -3,6 +3,7 @@
from skglm.solvers.base import BaseSolver
from skglm.solvers.common import construct_grad, construct_grad_sparse
from skglm.utils.prox_funcs import _prox_vec
from skglm.utils.validation import check_attrs


class FISTA(BaseSolver):
@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
https://epubs.siam.org/doi/10.1137/080716542
"""

_datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
_penalty_required_attr = (("prox_1d", "prox_vec"),)
Collaborator: Why does FISTA need prox_1d? It is not used in the code below.

Collaborator (reply): It is used in _prox_vec, which takes the penalty as an argument (cf. https://github.com/PABannier/skglm/blob/e687bc2ecaacfe920b9aaad3e33e1f0cbdbac683/skglm/solvers/fista.py#L83). The algorithm works if the penalty has either of prox_1d or prox_vec (for reference, #137 (comment)).

Collaborator: Got it, thanks!


def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
self.max_iter = max_iter
self.tol = tol
@@ -35,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
self.fit_intercept = False # needed to be passed to GeneralizedLinearEstimator
self.warm_start = False

def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
p_objs_out = []
n_samples, n_features = X.shape
all_features = np.arange(n_features)
Expand All @@ -46,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
z = w_init.copy() if w_init is not None else np.zeros(n_features)
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

try:
if X_is_sparse:
lipschitz = datafit.get_global_lipschitz_sparse(
X.data, X.indptr, X.indices, y
)
else:
lipschitz = datafit.get_global_lipschitz(X, y)
except AttributeError as e:
sparse_suffix = '_sparse' if X_is_sparse else ''

raise Exception(
"Datafit is not compatible with FISTA solver.\n Datafit must "
f"implement `get_global_lipschitz{sparse_suffix}` method") from e
if X_is_sparse:
lipschitz = datafit.get_global_lipschitz_sparse(
X.data, X.indptr, X.indices, y
)
else:
lipschitz = datafit.get_global_lipschitz(X, y)

for n_iter in range(self.max_iter):
t_old = t_new
@@ -111,3 +108,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break
return w, np.array(p_objs_out), stop_crit

def custom_checks(self, X, y, datafit, penalty):
# check datafit support sparse data
check_attrs(
datafit, solver=self,
required_attr=self._datafit_required_attr,
support_sparse=issparse(X)
)

# optimality check
if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
raise AttributeError(
"Penalty must implement `subdiff_distance` "
"to use `opt_strategy='subdiff'` in Fista solver."
)
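
Conversely, a sketch of a pairing that passes the new validation: Quadratic implements get_global_lipschitz and gradient_scalar, and L1 implements prox_1d and subdiff_distance, so FISTA's checks all succeed (the data and alpha are illustrative):

    import numpy as np

    from skglm.datafits import Quadratic
    from skglm.penalties import L1
    from skglm.solvers import FISTA
    from skglm.utils.jit_compilation import compiled_clone

    X, y = np.random.randn(30, 10), np.random.randn(30)

    datafit = compiled_clone(Quadratic())
    datafit.initialize(X, y)  # precompute the quantities gradient_scalar needs
    penalty = compiled_clone(L1(alpha=0.1))

    w, p_objs, stop_crit = FISTA().solve(X, y, datafit, penalty)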