26
26
27
27
import numpy as np
28
28
from sklearn .base import BaseEstimator
29
- from sklearn .linear_model import LinearRegression
29
+ from sklearn .linear_model import Lasso
30
30
31
31
from .data import Data
32
32
from ..utils import is_numeric_array , get_unique , is_positive_integer_or_zero_array
@@ -50,7 +50,7 @@ def __init__(self, data=None, elements='auto'):
50
50
self .elements = elements
51
51
52
52
# Initialize model
53
- self .model = LinearRegression ( )
53
+ self .model = Lasso ( alpha = 1e-9 )
54
54
55
55
def _preprocess_input (self , X ):
56
56
"""
@@ -117,15 +117,25 @@ def _check_data(self, X):
117
117
print ("Error: Expected Data object to have non-empty attribute 'energies'" % self .__class__ .__name__ )
118
118
raise SystemExit
119
119
120
-
121
120
def _set_data(self, data):
    """
    Validate a data object and attach it to this instance.

    Nothing is checked or stored when ``data`` is falsy (e.g. None),
    so any previously stored ``self.data`` is left as it is.

    :param data: Object to validate with ``_check_data`` and keep as ``self.data``
    :type data: Data object or None
    """

    # Guard clause: no data given, nothing to validate or store.
    if not data:
        return

    self._check_data(data)
    self.data = data
125
124
126
- def fit_transform (self , X , y = None ):
125
def _parse_input(self, X, y):
    """
    Resolve the supported input forms into one common triple.

    If X is not a Data object and y is given, X is used directly as the
    nuclear charges and y is passed through unchanged. In every other
    case X is handed to ``_preprocess_input`` and both the nuclear
    charges and y are read from the resulting data object at the
    positions given by its ``_indices``.

    :param X: Nuclear charges, or input understood by ``_preprocess_input``
    :type X: list or Data object
    :param y: Values belonging to X, or None to read them from the data object
    :type y: array or None
    :return: ``(data, nuclear_charges, y)``; ``data`` is None when X was
        used directly as nuclear charges
    :rtype: tuple
    """

    if isinstance(X, Data) or y is None:
        data = self._preprocess_input(X)
        selected = data._indices
        return data, data.nuclear_charges[selected], data.energies[selected]

    # Plain nuclear charges with explicit values: no data object involved.
    return None, X, y
135
+
136
+ def fit (self , X , y = None ):
127
137
"""
128
- Fit and transform the data with a linear model.
138
+ Fit the data with a linear model.
129
139
Supports three different types of input.
130
140
1) X is a list of nuclear charges and y is values to transform.
131
141
2) X is an array of indices of which to transform.
@@ -135,27 +145,58 @@ def fit_transform(self, X, y=None):
135
145
:type X: list
136
146
:param y: Values to transform
137
147
:type y: array or None
138
- :return: Array of transformed values or Data object, depending on input
139
- :rtype: array or Data object
140
148
"""
141
149
142
- if not isinstance (X , Data ) and y is not None :
143
- data = None
144
- nuclear_charges = X
145
- else :
146
- data = self ._preprocess_input (X )
147
- nuclear_charges = data .nuclear_charges [data ._indices ]
148
- y = data .energies [data ._indices ]
150
+ self ._fit (X , y )
151
+
152
+ return self
153
+
154
def _fit(self, X, y=None):
    """
    Does the work of the fit method, but returns variables
    that the fit_transform method can reuse.

    When ``self.elements`` is the string 'auto' it is replaced by
    ``get_unique(nuclear_charges)``; otherwise the nuclear charges are
    validated with ``_check_elements``. Note that after an 'auto' fit
    ``self.elements`` no longer holds 'auto', so later fits validate
    against the stored elements instead.

    :param X: Input accepted by ``_parse_input``
    :type X: list or Data object
    :param y: Values to fit, or None to read them from the data object
    :type y: array or None
    :return: ``(data, features, y)`` as produced by ``_parse_input`` and
        ``_featurizer``
    :rtype: tuple
    """

    data, nuclear_charges, y = self._parse_input(X, y)

    # Only a string can be the 'auto' sentinel. Checking the type first
    # avoids comparing array-like elements (e.g. supplied by the caller)
    # against a string: NumPy performs that comparison elementwise, and
    # the resulting array cannot be used as an ``if`` condition.
    if isinstance(self.elements, str) and self.elements == 'auto':
        self.elements = get_unique(nuclear_charges)
    else:
        self._check_elements(nuclear_charges)

    features = self._featurizer(nuclear_charges)

    self.model.fit(features, y)

    return data, features, y
172
+
173
def fit_transform(self, X, y=None):
    """
    Fit the linear model and transform the same data in one call.

    The input is fitted through ``_fit`` and the values it returns are
    passed straight on to ``_transform``.

    Supports three different types of input.
    1) X is a list of nuclear charges and y is values to transform.
    2) X is an array of indices of which to transform.
    3) X is a data object

    :param X: List with nuclear charges or Data object.
    :type X: list
    :param y: Values to transform
    :type y: array or None
    :return: Array of transformed values or Data object, depending on input
    :rtype: array or Data object
    """

    fitted = self._fit(X, y)
    data, features, targets = fitted

    return self._transform(data, features, targets)
192
+
193
+ def _transform (self , data , features , y ):
194
+ """
195
+ Does the work of the transform method, but can be reused by
196
+ the fit_transform method.
197
+ """
198
+
199
+ delta_y = y - self .model .predict (features )
159
200
160
201
if data :
161
202
# Force copy
@@ -213,25 +254,10 @@ def transform(self, X, y=None):
213
254
:rtype: array or Data object
214
255
"""
215
256
216
- if not isinstance (X , Data ) and y is not None :
217
- data = None
218
- nuclear_charges = X
219
- else :
220
- data = self ._preprocess_input (X )
221
- nuclear_charges = data .nuclear_charges [data ._indices ]
222
- y = data .energies [data ._indices ]
257
+ data , nuclear_charges , y = self ._parse_input (X , y )
223
258
224
259
self ._check_elements (nuclear_charges )
225
260
226
261
features = self ._featurizer (nuclear_charges )
227
262
228
- delta_y = y - self .model .predict (features )
229
-
230
- if data :
231
- # Force copy
232
- data .energies = data .energies .copy ()
233
- data .energies [data ._indices ] = delta_y
234
- return data
235
- else :
236
- return delta_y
237
-
263
+ return self ._transform (data , features , y )
0 commit comments