Skip to content

Commit 306826f

Browse files
maikiaglemaitreNicolasHugagramfortogrisel
authored
MRG Deprecates 'normalize' in LinearRegression (_base.py) (scikit-learn#17743)
Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Alexandre Gramfort <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
1 parent 8ea176a commit 306826f

File tree

4 files changed

+260
-6
lines changed

4 files changed

+260
-6
lines changed

doc/whats_new/v1.0.rst

+12
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ Changelog
9494
Use ``var_`` instead.
9595
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
9696

97+
- |API|: The parameter ``normalize`` of :class:`linear_model.LinearRegression`
98+
is deprecated and will be removed in 1.2.
99+
Motivation for this deprecation: ``normalize`` parameter did not take any
100+
effect if ``fit_intercept`` was set to False and therefore was deemed
101+
confusing.
102+
The behavior of the deprecated LinearRegression(normalize=True) can be
103+
reproduced with :class:`~sklearn.pipeline.Pipeline` with
104+
:class:`~sklearn.preprocessing.StandardScaler` as follows:
105+
``make_pipeline(StandardScaler(with_mean=False), LinearRegression())``.
106+
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
107+
:user:`Alexandre Gramfort <agramfort>`.
108+
97109
Code and Documentation Contributors
98110
-----------------------------------
99111

sklearn/linear_model/_base.py

+105-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# Lars Buitinck
1212
# Maryan Morel <[email protected]>
1313
# Giorgio Patrini <[email protected]>
14+
# Maria Telenczuk <https://github.com/maikia>
1415
# License: BSD 3 clause
1516

1617
from abc import ABCMeta, abstractmethod
@@ -49,6 +50,94 @@
4950
# intercept oscillation.
5051

5152

53+
# FIXME in 1.2: parameter 'normalize' should be removed from linear models
54+
# in cases where now normalize=False. The default value of 'normalize' should
55+
# be changed to False in linear models where now normalize=True
56+
def _deprecate_normalize(normalize, default, estimator_name):
    """Resolve the value of the deprecated ``normalize`` parameter.

    ``normalize`` is being deprecated from the linear models and the use of
    a pipeline with a StandardScaler is recommended instead. This helper
    validates the value passed by the user, emits the ``FutureWarning`` that
    matches the combination of that value and the estimator's default (the
    default varies between the linear models), and returns the boolean that
    the estimator should actually use.

    Parameters
    ----------
    normalize : bool or 'deprecated'
        normalize value passed by the user; ``'deprecated'`` is the sentinel
        meaning that the user left the parameter untouched.

    default : bool
        default normalize value used by the estimator

    estimator_name : str
        name of the linear estimator which calls this function.
        The name will be used for writing the deprecation warnings

    Returns
    -------
    normalize : bool
        normalize value which should further be used by the estimator at this
        stage of the deprecation process

    Raises
    ------
    ValueError
        If ``normalize`` is not True, False or ``'deprecated'``.

    Notes
    -----
    This function should be updated in 1.2 depending on the value of
    `normalize`:
    - True, warning: `normalize` was deprecated in 1.2 and will be removed in
      1.4. Suggest to use pipeline instead.
    - False, `normalize` was deprecated in 1.2 and it will be removed in 1.4.
      Leave normalize to its default value.
    - `deprecated` - this should only be possible with default == False as from
      1.2 `normalize` in all the linear models should be either removed or the
      default should be set to False.
    This function should be completely removed in 1.4.
    """
    if normalize not in [True, False, 'deprecated']:
        raise ValueError("Leave 'normalize' to its default value or set it "
                         "to True or False")

    if normalize == 'deprecated':
        # The user did not set the parameter: fall back on the estimator's
        # default and only warn when that default is about to change.
        if default:
            warnings.warn(
                "The default of 'normalize' will be set to False in version "
                "1.2 and deprecated in version 1.4. \nPass normalize=False "
                "and use Pipeline with a StandardScaler in a preprocessing "
                "stage if you wish to reproduce the previous behavior:\n"
                "model = make_pipeline(StandardScaler(with_mean=False), \n"
                f"{estimator_name}(normalize=False))\n"
                "If you wish to use additional parameters in "
                "the fit() you can include them as follows:\n"
                "kwargs = {model.steps[-1][0] + "
                "'__<your_param_name>': <your_param_value>}\n"
                "model.fit(X, y, **kwargs)", FutureWarning
            )
        return default

    # From here on the user passed an explicit boolean. Estimators whose
    # default is True keep accepting it silently during this release.
    if default:
        return normalize

    if normalize:
        warnings.warn(
            "'normalize' was deprecated in version 1.0 and will be "
            "removed in 1.2.\nIf you still wish to normalize, use "
            "Pipeline with a StandardScaler in a preprocessing stage "
            "to reproduce the previous behavior:\n"
            "model = make_pipeline(StandardScaler(with_mean=False), "
            f"{estimator_name}())\nIf you wish to use additional "
            "parameters in the fit() you can include them as follows:\n"
            "kwargs = {model.steps[-1][0] + "
            "'__<your_param_name>': <your_param_value>}\n"
            "model.fit(X, y, **kwargs)", FutureWarning
        )
    else:
        warnings.warn(
            "'normalize' was deprecated in version 1.0 and will be"
            " removed in 1.2. Don't set 'normalize' parameter"
            " and leave it to its default value", FutureWarning
        )

    return normalize
139+
140+
52141
def make_dataset(X, y, sample_weight, random_state=None):
53142
"""Create ``Dataset`` abstraction for sparse and dense inputs.
54143
@@ -407,6 +496,10 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
407496
:class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
408497
on an estimator with ``normalize=False``.
409498
499+
.. deprecated:: 1.0
500+
`normalize` was deprecated in version 1.0 and will be
501+
removed in 1.2.
502+
410503
copy_X : bool, default=True
411504
If True, X will be copied; else, it may be overwritten.
412505
@@ -476,8 +569,8 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
476569
array([16.])
477570
"""
478571
@_deprecate_positional_args
479-
def __init__(self, *, fit_intercept=True, normalize=False, copy_X=True,
480-
n_jobs=None, positive=False):
572+
def __init__(self, *, fit_intercept=True, normalize='deprecated',
573+
copy_X=True, n_jobs=None, positive=False):
481574
self.fit_intercept = fit_intercept
482575
self.normalize = normalize
483576
self.copy_X = copy_X
@@ -507,6 +600,11 @@ def fit(self, X, y, sample_weight=None):
507600
self : returns an instance of self.
508601
"""
509602

603+
_normalize = _deprecate_normalize(
604+
self.normalize, default=False,
605+
estimator_name=self.__class__.__name__
606+
)
607+
510608
n_jobs_ = self.n_jobs
511609

512610
accept_sparse = False if self.positive else ['csr', 'csc', 'coo']
@@ -519,7 +617,7 @@ def fit(self, X, y, sample_weight=None):
519617
dtype=X.dtype)
520618

521619
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
522-
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
620+
X, y, fit_intercept=self.fit_intercept, normalize=_normalize,
523621
copy=self.copy_X, sample_weight=sample_weight,
524622
return_mean=True)
525623

@@ -651,10 +749,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
651749
check_input=check_input, sample_weight=sample_weight)
652750
if sample_weight is not None:
653751
X, y = _rescale_data(X, y, sample_weight=sample_weight)
752+
753+
# FIXME: 'normalize' to be removed in 1.2
654754
if hasattr(precompute, '__array__'):
655755
if (fit_intercept and not np.allclose(X_offset, np.zeros(n_features))
656-
or normalize and not np.allclose(X_scale,
657-
np.ones(n_features))):
756+
or normalize and not np.allclose(X_scale, np.ones(n_features)
757+
)):
658758
warnings.warn(
659759
"Gram matrix was provided but X was centered to fit "
660760
"intercept, or X was normalized : recomputing Gram matrix.",

sklearn/linear_model/tests/test_base.py

+86-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.utils.fixes import parse_version
1818

1919
from sklearn.linear_model import LinearRegression
20+
from sklearn.linear_model._base import _deprecate_normalize
2021
from sklearn.linear_model._base import _preprocess_data
2122
from sklearn.linear_model._base import _rescale_data
2223
from sklearn.linear_model._base import make_dataset
@@ -106,6 +107,7 @@ def test_raises_value_error_if_positive_and_sparse():
106107
with pytest.raises(TypeError, match=error_msg):
107108
reg.fit(X, y)
108109

110+
109111
def test_raises_value_error_if_sample_weights_greater_than_1d():
110112
# Sample weights must be either scalar or 1D
111113

@@ -149,6 +151,59 @@ def test_fit_intercept():
149151
lr3_without_intercept.coef_.ndim)
150152

151153

154+
def test_error_on_wrong_normalize():
    # Any value other than True, False or 'deprecated' must be rejected with
    # a ValueError telling the user which values are allowed.
    normalize = 'wrong'
    default = True
    error_msg = "Leave 'normalize' to its default"
    with pytest.raises(ValueError, match=error_msg):
        _deprecate_normalize(normalize, default, 'estimator')
161+
162+
163+
@pytest.mark.parametrize('normalize', [True, False, 'deprecated'])
@pytest.mark.parametrize('default', [True, False])
# FIXME update test in 1.2 for new versions
def test_deprecate_normalize(normalize, default):
    # test all possible case of the normalize parameter deprecation
    #
    # Warnings are recorded with the standard library instead of
    # ``pytest.warns(None)``, which recent pytest versions reject.
    import warnings

    if normalize == 'deprecated':
        # parameter left untouched: the estimator default is used and a
        # warning is only expected when that default is going to change
        output = default
        warning_msg = ['False', '1.2', 'StandardScaler('] if default else []
    else:
        output = normalize
        if default:
            # explicit value on an estimator defaulting to True: no warning
            warning_msg = []
        elif normalize:
            warning_msg = ['1.2', 'StandardScaler(']
        else:
            warning_msg = ['1.2', 'default value']

    n_warnings = 1 if warning_msg else 0

    with warnings.catch_warnings(record=True) as record:
        # make sure nothing is swallowed by an already-installed filter
        warnings.simplefilter("always")
        _normalize = _deprecate_normalize(normalize, default, 'estimator')
    assert _normalize == output

    assert len(record) == n_warnings
    if n_warnings:
        assert issubclass(record[0].category, FutureWarning)
        message = str(record[0].message)
        assert all(fragment in message for fragment in warning_msg)
205+
206+
152207
def test_linear_regression_sparse(random_state=0):
153208
# Test that linear regression also works with sparse data
154209
random_state = check_random_state(random_state)
@@ -165,6 +220,35 @@ def test_linear_regression_sparse(random_state=0):
165220
assert_array_almost_equal(ols.predict(X) - y.ravel(), 0)
166221

167222

223+
@pytest.mark.parametrize(
    'normalize, n_warnings, warning',
    [(True, 1, FutureWarning),
     (False, 1, FutureWarning),
     ("deprecated", 0, None)]
)
# FIXME remove test in 1.4
def test_linear_regression_normalize_deprecation(
    normalize, n_warnings, warning
):
    # check that we issue a FutureWarning when normalize was set in
    # LinearRegression, and none when it is left to its default.
    #
    # Warnings are recorded with the standard library instead of
    # ``pytest.warns(None)``, which recent pytest versions reject.
    import warnings

    rng = check_random_state(0)
    n_samples = 200
    n_features = 2
    X = rng.randn(n_samples, n_features)
    X[X < 0.1] = 0.0
    y = rng.rand(n_samples)

    model = LinearRegression(normalize=normalize)
    with warnings.catch_warnings(record=True) as record:
        # make sure nothing is swallowed by an already-installed filter
        warnings.simplefilter("always")
        model.fit(X, y)
    assert len(record) == n_warnings
    if n_warnings:
        assert issubclass(record[0].category, warning)
        assert "'normalize' was deprecated" in str(record[0].message)
248+
249+
250+
# FIXME: 'normalize' to be removed in 1.2 in LinearRegression
251+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
168252
@pytest.mark.parametrize('normalize', [True, False])
169253
@pytest.mark.parametrize('fit_intercept', [True, False])
170254
def test_linear_regression_sparse_equal_dense(normalize, fit_intercept):
@@ -303,8 +387,9 @@ def test_linear_regression_pd_sparse_dataframe_warning():
303387
df[str(col)] = arr
304388

305389
msg = "pandas.DataFrame with sparse columns found."
390+
391+
reg = LinearRegression()
306392
with pytest.warns(UserWarning, match=msg):
307-
reg = LinearRegression()
308393
reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
309394

310395
# does not warn when the whole dataframe is sparse

sklearn/linear_model/tests/test_coordinate_descent.py

+57
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sklearn.utils._testing import assert_warns_message
2727
from sklearn.utils._testing import ignore_warnings
2828
from sklearn.utils._testing import assert_array_equal
29+
from sklearn.utils._testing import _convert_container
2930
from sklearn.utils._testing import TempMemmap
3031
from sklearn.utils.fixes import parse_version
3132

@@ -301,6 +302,8 @@ def test_lasso_cv_positive_constraint():
301302
assert min(clf_constrained.coef_) >= 0
302303

303304

305+
# FIXME: 'normalize' to be removed in 1.2
306+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
304307
@pytest.mark.parametrize(
305308
"LinearModel, params",
306309
[(Lasso, {"tol": 1e-16, "alpha": 0.1}),
@@ -384,6 +387,60 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params):
384387
assert_allclose(y_pred_normalize, y_pred_standardize)
385388

386389

390+
# FIXME: 'normalize' to be removed in 1.2
391+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
@pytest.mark.parametrize(
    "estimator, is_sparse, with_mean",
    [(LinearRegression, True, False),
     (LinearRegression, False, True),
     (LinearRegression, False, False)]
)
def test_linear_model_sample_weights_normalize_in_pipeline(
        estimator, is_sparse, with_mean
):
    # Test that the results for running linear regression LinearRegression with
    # sample_weight set and with normalize set to True gives similar results as
    # LinearRegression with no normalize in a pipeline with a StandardScaler
    # and set sample_weight.
    rng = np.random.RandomState(0)
    X, y = make_regression(n_samples=20, n_features=5, noise=1e-2,
                           random_state=rng)
    # make sure the data is not centered to make the problem more
    # difficult
    X += 10
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,
                                                        random_state=rng)
    if is_sparse:
        X_train = sparse.csr_matrix(X_train)
        # convert the held-out split itself; converting X_train here would
        # silently turn the prediction check into a check on training data
        X_test = _convert_container(X_test, 'sparse')

    sample_weight = rng.rand(X_train.shape[0])

    # linear estimator with explicit sample_weight
    reg_with_normalize = estimator(normalize=True)
    reg_with_normalize.fit(X_train, y_train, sample_weight=sample_weight)

    # linear estimator in a pipeline
    reg_with_scaler = make_pipeline(
        StandardScaler(with_mean=with_mean),
        estimator(normalize=False)
    )
    # route sample_weight to the final step of the pipeline
    kwargs = {reg_with_scaler.steps[-1][0] + '__sample_weight':
              sample_weight}
    reg_with_scaler.fit(X_train, y_train, **kwargs)

    y_pred_norm = reg_with_normalize.predict(X_test)
    y_pred_pip = reg_with_scaler.predict(X_test)

    # coefficients are compared after rescaling by the scaler's scale_
    assert_allclose(
        reg_with_normalize.coef_ * reg_with_scaler[0].scale_,
        reg_with_scaler[1].coef_
    )
    assert_allclose(y_pred_norm, y_pred_pip)
440+
441+
442+
# FIXME: 'normalize' to be removed in 1.2
443+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
387444
@pytest.mark.parametrize(
388445
"LinearModel, params",
389446
[(Lasso, {"tol": 1e-16, "alpha": 0.1}),

0 commit comments

Comments
 (0)