Skip to content
This repository was archived by the owner on Dec 8, 2024. It is now read-only.

Commit 564ed86

Browse files
larsbratholmandersx
authored andcommitted
Changed linear fit to lasso regression (#93)
This was done to help fitting of underdetermined cases, such as when the chemical composition never or rarely changes (md snapshots etc). Also a slight restructure was done to avoid repeated code.
1 parent 04fae02 commit 564ed86

File tree

1 file changed

+59
-33
lines changed

1 file changed

+59
-33
lines changed

qml/qmlearn/preprocessing.py

+59-33
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import numpy as np
2828
from sklearn.base import BaseEstimator
29-
from sklearn.linear_model import LinearRegression
29+
from sklearn.linear_model import Lasso
3030

3131
from .data import Data
3232
from ..utils import is_numeric_array, get_unique, is_positive_integer_or_zero_array
@@ -50,7 +50,7 @@ def __init__(self, data=None, elements='auto'):
5050
self.elements = elements
5151

5252
# Initialize model
53-
self.model = LinearRegression()
53+
self.model = Lasso(alpha=1e-9)
5454

5555
def _preprocess_input(self, X):
5656
"""
@@ -117,15 +117,25 @@ def _check_data(self, X):
117117
print("Error: Expected Data object to have non-empty attribute 'energies'" % self.__class__.__name__)
118118
raise SystemExit
119119

120-
121120
def _set_data(self, data):
122121
if data:
123122
self._check_data(data)
124123
self.data = data
125124

126-
def fit_transform(self, X, y=None):
125+
def _parse_input(self, X, y):
126+
if not isinstance(X, Data) and y is not None:
127+
data = None
128+
nuclear_charges = X
129+
else:
130+
data = self._preprocess_input(X)
131+
nuclear_charges = data.nuclear_charges[data._indices]
132+
y = data.energies[data._indices]
133+
134+
return data, nuclear_charges, y
135+
136+
def fit(self, X, y=None):
127137
"""
128-
Fit and transform the data with a linear model.
138+
Fit the data with a linear model.
129139
Supports three different types of input.
130140
1) X is a list of nuclear charges and y is values to transform.
131141
2) X is an array of indices of which to transform.
@@ -135,27 +145,58 @@ def fit_transform(self, X, y=None):
135145
:type X: list
136146
:param y: Values to transform
137147
:type y: array or None
138-
:return: Array of transformed values or Data object, depending on input
139-
:rtype: array or Data object
140148
"""
141149

142-
if not isinstance(X, Data) and y is not None:
143-
data = None
144-
nuclear_charges = X
145-
else:
146-
data = self._preprocess_input(X)
147-
nuclear_charges = data.nuclear_charges[data._indices]
148-
y = data.energies[data._indices]
150+
self._fit(X, y)
151+
152+
return self
153+
154+
def _fit(self, X, y=None):
155+
"""
156+
Does the work of the fit method, but returns variables
157+
that the fit_transform method can reuse.
158+
"""
159+
160+
data, nuclear_charges, y = self._parse_input(X, y)
149161

150162
if self.elements == 'auto':
151163
self.elements = get_unique(nuclear_charges)
152164
else:
153165
self._check_elements(nuclear_charges)
154166

155-
156167
features = self._featurizer(nuclear_charges)
157168

158-
delta_y = y - self.model.fit(features, y).predict(features)
169+
self.model.fit(features, y)
170+
171+
return data, features, y
172+
173+
def fit_transform(self, X, y=None):
174+
"""
175+
Fit and transform the data with a linear model.
176+
Supports three different types of input.
177+
1) X is a list of nuclear charges and y is values to transform.
178+
2) X is an array of indices of which to transform.
179+
3) X is a data object
180+
181+
:param X: List with nuclear charges or Data object.
182+
:type X: list
183+
:param y: Values to transform
184+
:type y: array or None
185+
:return: Array of transformed values or Data object, depending on input
186+
:rtype: array or Data object
187+
"""
188+
189+
data, features, y = self._fit(X, y)
190+
191+
return self._transform(data, features, y)
192+
193+
def _transform(self, data, features, y):
194+
"""
195+
Does the work of the transform method, but can be reused by
196+
the fit_transform method.
197+
"""
198+
199+
delta_y = y - self.model.predict(features)
159200

160201
if data:
161202
# Force copy
@@ -213,25 +254,10 @@ def transform(self, X, y=None):
213254
:rtype: array or Data object
214255
"""
215256

216-
if not isinstance(X, Data) and y is not None:
217-
data = None
218-
nuclear_charges = X
219-
else:
220-
data = self._preprocess_input(X)
221-
nuclear_charges = data.nuclear_charges[data._indices]
222-
y = data.energies[data._indices]
257+
data, nuclear_charges, y = self._parse_input(X, y)
223258

224259
self._check_elements(nuclear_charges)
225260

226261
features = self._featurizer(nuclear_charges)
227262

228-
delta_y = y - self.model.predict(features)
229-
230-
if data:
231-
# Force copy
232-
data.energies = data.energies.copy()
233-
data.energies[data._indices] = delta_y
234-
return data
235-
else:
236-
return delta_y
237-
263+
return self._transform(data, features, y)

0 commit comments

Comments
 (0)