WIP Add support for GLasso and Adaptive (reweighted) GLasso #280
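A minimal usage sketch of the two estimators this PR adds, assembled from the diff below (the `generate_GraphicalLasso_data` helper and all other names come from this diff; exact signatures may still change while the PR is WIP). Both estimators implement the graphical lasso, i.e. weighted-L1-penalized sparse precision (inverse covariance) matrix estimation, and are fit directly on an empirical covariance matrix `S` rather than on raw samples:

```python
from skglm.estimators import GraphicalLasso, AdaptiveGraphicalLasso
from skglm.utils.data import generate_GraphicalLasso_data

# simulate a covariance S and a ground-truth sparse precision matrix
S, Theta_true, alpha_max = generate_GraphicalLasso_data(100, 20)  # (n, p)

# plain GLasso, solved by block coordinate descent over columns
# ("mazumder" = primal column updates)
glasso = GraphicalLasso(alpha=0.1 * alpha_max, algo="mazumder").fit(S)

# adaptive (reweighted) GLasso: refits n_reweights times with
# weights 1 / (|Theta| + eps) to reduce the L1 shrinkage bias
a_glasso = AdaptiveGraphicalLasso(alpha=0.1 * alpha_max, n_reweights=5).fit(S)

print(glasso.precision_.shape, a_glasso.precision_.shape)  # (20, 20) (20, 20)
```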

113 changes: 113 additions & 0 deletions examples/plot_graphical_lasso.py
@@ -0,0 +1,113 @@
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score

from skglm.estimators import GraphicalLasso, AdaptiveGraphicalLasso
from skglm.utils.data import generate_GraphicalLasso_data

# Data
p = 20
n = 100
S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p)

alphas = alpha_max*np.geomspace(1, 1e-3, num=30)


penalties = [
"L1",
"R-L1",
]

models_tol = 1e-4
models = [
    GraphicalLasso(algo="mazumder", warm_start=True, tol=models_tol),
    AdaptiveGraphicalLasso(warm_start=True, n_reweights=10, tol=models_tol),
]

my_glasso_nmses = {penalty: [] for penalty in penalties}
my_glasso_f1_scores = {penalty: [] for penalty in penalties}


for i, (penalty, model) in enumerate(zip(penalties, models)):
print(penalty)
for alpha_idx, alpha in enumerate(alphas):
print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======")
model.alpha = alpha
model.fit(S)
Theta = model.precision_

my_nmse = norm(Theta - Theta_true)**2 / norm(Theta_true)**2

my_f1_score = f1_score(Theta.flatten() != 0.,
Theta_true.flatten() != 0.)
print(f"NMSE: {my_nmse:.3f}")
print(f"F1 : {my_f1_score:.3f}")

my_glasso_nmses[penalty].append(my_nmse)
my_glasso_f1_scores[penalty].append(my_f1_score)


plt.close('all')
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12.6, 4.63),
                       layout="constrained")
cmap = plt.get_cmap("tab10")
for i, penalty in enumerate(penalties):

ax[0].semilogx(alphas/alpha_max,
my_glasso_nmses[penalty],
color=cmap(i),
linewidth=2.,
label=penalty)
min_nmse = np.argmin(my_glasso_nmses[penalty])
ax[0].vlines(
x=alphas[min_nmse] / alphas[0],
ymin=0,
ymax=np.min(my_glasso_nmses[penalty]),
linestyle='--',
color=cmap(i))
line0 = ax[0].plot(
[alphas[min_nmse] / alphas[0]],
0,
clip_on=False,
marker='X',
color=cmap(i),
markersize=12)

ax[1].semilogx(alphas/alpha_max,
my_glasso_f1_scores[penalty],
linewidth=2.,
color=cmap(i))
max_f1 = np.argmax(my_glasso_f1_scores[penalty])
ax[1].vlines(
x=alphas[max_f1] / alphas[0],
ymin=0,
ymax=np.max(my_glasso_f1_scores[penalty]),
linestyle='--',
color=cmap(i))
line1 = ax[1].plot(
[alphas[max_f1] / alphas[0]],
0,
clip_on=False,
marker='X',
markersize=12,
color=cmap(i))


ax[0].set_title(f"{p=},{n=}", fontsize=18)
ax[0].set_ylabel("NMSE", fontsize=18)
ax[1].set_ylabel("F1 score", fontsize=18)
ax[1].set_xlabel(r"$\lambda / \lambda_\mathrm{max}$", fontsize=18)

ax[0].legend(fontsize=14)
ax[0].grid(which='both', alpha=0.9)
ax[1].grid(which='both', alpha=0.9)
# plt.show(block=False)
plt.show()
189 changes: 177 additions & 12 deletions skglm/estimators.py
@@ -21,7 +21,7 @@
from skglm.utils.jit_compilation import compiled_clone
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
-                            QuadraticMultiTask, QuadraticGroup,)
+                            QuadraticMultiTask, QuadraticGroup, QuadraticHessian)
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
from skglm.utils.data import grp_converter
@@ -126,7 +126,8 @@ def _glm_fit(X, y, model, datafit, penalty, solver):
w = np.zeros(n_features + fit_intercept, dtype=X_.dtype)
Xw = np.zeros(n_samples, dtype=X_.dtype)
else: # multitask
-        w = np.zeros((n_features + fit_intercept, y.shape[1]), dtype=X_.dtype)
+        w = np.zeros((n_features + fit_intercept,
+                      y.shape[1]), dtype=X_.dtype)
Xw = np.zeros(y.shape, dtype=X_.dtype)

# check consistency of weights for WeightedL1
@@ -576,7 +577,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
raise ValueError("The number of weights must match the number of \
features. Got %s, expected %s." % (
len(weights), X.shape[1]))
-        penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive))
+        penalty = compiled_clone(WeightedL1(
+            self.alpha, weights, self.positive))
datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
solver = AndersonCD(
self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -601,7 +603,8 @@ def fit(self, X, y):
Fitted estimator.
"""
if self.weights is None:
-            warnings.warn('Weights are not provided, fitting with Lasso penalty')
+            warnings.warn(
+                'Weights are not provided, fitting with Lasso penalty')
penalty = L1(self.alpha, self.positive)
else:
penalty = WeightedL1(self.alpha, self.weights, self.positive)
@@ -734,7 +737,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
The number of iterations along the path. If return_n_iter is set to
``True``.
"""
-        penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive))
+        penalty = compiled_clone(L1_plus_L2(
+            self.alpha, self.l1_ratio, self.positive))
datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
solver = AndersonCD(
self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -912,7 +916,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
f"Got {len(self.weights)}, expected {X.shape[1]}."
)
penalty = compiled_clone(
-            WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
+            WeightedMCPenalty(self.alpha, self.gamma,
+                              self.weights, self.positive)
)
datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
solver = AndersonCD(
@@ -1307,7 +1312,8 @@ def fit(self, X, y):
# copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \
# 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \
# estimator_checks.py#L3886
raise ValueError("requires y to be passed, but the target y is None")
raise ValueError(
"requires y to be passed, but the target y is None")
y = check_array(
y,
accept_sparse=False,
@@ -1322,7 +1328,8 @@ def fit(self, X, y):
f"two columns. Got one column.\nAssuming that `y` "
"is the vector of times and there is no censoring."
)
-            y = np.column_stack((y, np.ones_like(y))).astype(X.dtype, order="F")
+            y = np.column_stack((y, np.ones_like(y))).astype(
+                X.dtype, order="F")
elif y.shape[1] > 2:
raise ValueError(
f"{repr(self)} requires the vector of response `y` to have "
@@ -1347,7 +1354,8 @@ def fit(self, X, y):

# init solver
if self.l1_ratio == 0.:
-            solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose)
+            solver = LBFGS(max_iter=self.max_iter,
+                           tol=self.tol, verbose=self.verbose)
else:
solver = ProxNewton(
max_iter=self.max_iter, tol=self.tol, verbose=self.verbose,
@@ -1485,7 +1493,8 @@ def fit(self, X, Y):
if not self.warm_start or not hasattr(self, "coef_"):
self.coef_ = None

-        datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32)
+        datafit_jit = compiled_clone(
+            QuadraticMultiTask(), X.dtype == np.float32)
penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32)

solver = MultiTaskBCD(
@@ -1540,7 +1549,8 @@ def path(self, X, Y, alphas, coef_init=None, return_n_iter=False, **params):
The number of iterations along the path. If return_n_iter is set to
``True``.
"""
-        datafit = compiled_clone(QuadraticMultiTask(), to_float32=X.dtype == np.float32)
+        datafit = compiled_clone(QuadraticMultiTask(),
+                                 to_float32=X.dtype == np.float32)
penalty = compiled_clone(L2_1(self.alpha))
solver = MultiTaskBCD(
self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -1664,7 +1674,8 @@ def fit(self, X, y):
"The total number of group members must equal the number of features. "
f"Got {n_features}, expected {X.shape[1]}.")

-        weights = np.ones(len(group_sizes)) if self.weights is None else self.weights
+        weights = np.ones(
+            len(group_sizes)) if self.weights is None else self.weights
group_penalty = WeightedGroupL2(alpha=self.alpha, grp_ptr=grp_ptr,
grp_indices=grp_indices, weights=weights,
positive=self.positive)
@@ -1675,3 +1686,157 @@ def fit(self, X, y):
verbose=self.verbose)

return _glm_fit(X, y, self, quad_group, group_penalty, solver)


class GraphicalLasso():
def __init__(self,
alpha=1.,
weights=None,
algo="banerjee",
max_iter=1000,
tol=1e-8,
warm_start=False,
):
self.alpha = alpha
self.weights = weights
self.algo = algo
self.max_iter = max_iter
self.tol = tol
self.warm_start = warm_start

def fit(self, S):
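        # S is the (p, p) empirical covariance matrix. The precision matrix
        # Theta is estimated by block coordinate descent: each pass updates
        # one column/row of (Theta, W) by solving a weighted-L1 quadratic
        # subproblem with the inner AndersonCD solver.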
p = S.shape[-1]
indices = np.arange(p)

if self.weights is None:
Weights = np.ones((p, p))
else:
Weights = self.weights
if not np.allclose(Weights, Weights.T):
raise ValueError("Weights should be symmetric.")

if self.warm_start and hasattr(self, "precision_"):
if self.algo == "banerjee":
raise ValueError(
"Banerjee does not support warm start for now.")
Theta = self.precision_
W = self.covariance_
else:
W = S.copy() # + alpha*np.eye(p)
Theta = np.linalg.pinv(W, hermitian=True)

datafit = compiled_clone(QuadraticHessian())
penalty = compiled_clone(
WeightedL1(alpha=self.alpha, weights=Weights[0, :-1]))

solver = AndersonCD(warm_start=True,
fit_intercept=False,
ws_strategy="fixpoint")

for it in range(self.max_iter):
Theta_old = Theta.copy()
for col in range(p):
indices_minus_col = np.concatenate(
[indices[:col], indices[col + 1:]])
_11 = indices_minus_col[:, None], indices_minus_col[None]
_12 = indices_minus_col, col
_21 = col, indices_minus_col
_22 = col, col

W_11 = W[_11]
w_12 = W[_12]
w_22 = W[_22]
s_12 = S[_12]
s_22 = S[_22]

penalty.weights = Weights[_12]

if self.algo == "banerjee":
w_init = Theta[_12]/Theta[_22]
Xw_init = W_11 @ w_init
Q = W_11
elif self.algo == "mazumder":
inv_Theta_11 = W_11 - np.outer(w_12, w_12)/w_22
Q = inv_Theta_11
w_init = Theta[_12] * w_22
Xw_init = inv_Theta_11 @ w_init
else:
raise ValueError(f"Unsupported algo {self.algo}")

beta, _, _ = solver._solve(
Q,
s_12,
datafit,
penalty,
w_init=w_init,
Xw_init=Xw_init,
)

if self.algo == "banerjee":
w_12 = -W_11 @ beta
W[_12] = w_12
W[_21] = w_12
Theta[_22] = 1/(s_22 + beta @ w_12)
Theta[_12] = beta*Theta[_22]
else: # mazumder
theta_12 = beta / s_22
theta_22 = 1/s_22 + theta_12 @ inv_Theta_11 @ theta_12

Theta[_12] = theta_12
Theta[_21] = theta_12
Theta[_22] = theta_22

w_22 = 1/(theta_22 - theta_12 @ inv_Theta_11 @ theta_12)
w_12 = -w_22*inv_Theta_11 @ theta_12
W_11 = inv_Theta_11 + np.outer(w_12, w_12)/w_22
W[_11] = W_11
W[_12] = w_12
W[_21] = w_12
W[_22] = w_22

if np.linalg.norm(Theta - Theta_old) < self.tol:
print(f"Weighted Glasso converged at CD epoch {it + 1}")
break
else:
print(f"Not converged at epoch {it + 1}, "
f"diff={np.linalg.norm(Theta - Theta_old):.2e}")
self.precision_, self.covariance_ = Theta, W
self.n_iter_ = it + 1

return self


class AdaptiveGraphicalLasso():
def __init__(
self,
alpha=1.,
n_reweights=5,
max_iter=1000,
tol=1e-8,
warm_start=False,
# verbose=False,
):
self.alpha = alpha
self.n_reweights = n_reweights
self.max_iter = max_iter
self.tol = tol
self.warm_start = warm_start

def fit(self, S):
glasso = GraphicalLasso(
alpha=self.alpha, algo="mazumder", max_iter=self.max_iter,
tol=self.tol, warm_start=True)
Weights = np.ones(S.shape)
self.n_iter_ = []
for it in range(self.n_reweights):
glasso.weights = Weights
glasso.fit(S)
Theta = glasso.precision_
Weights = 1/(np.abs(Theta) + 1e-10)
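            # Reweighted L1: small entries of Theta get large weights and are
            # pushed to zero, while large entries are penalized less (the
            # usual 1 / (|x| + eps) scheme of reweighted/adaptive lasso).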
self.n_iter_.append(glasso.n_iter_)
# TODO print losses for original problem?
glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True)
self.precision_ = glasso.precision_
self.covariance_ = glasso.covariance_

return self