diff --git a/qml/qmlearn/preprocessing.py b/qml/qmlearn/preprocessing.py
index cab329075..2a28bb4a5 100644
--- a/qml/qmlearn/preprocessing.py
+++ b/qml/qmlearn/preprocessing.py
@@ -26,7 +26,7 @@
 
 import numpy as np
 from sklearn.base import BaseEstimator
-from sklearn.linear_model import LinearRegression
+from sklearn.linear_model import Lasso
 
 from .data import Data
 from ..utils import is_numeric_array, get_unique, is_positive_integer_or_zero_array
@@ -50,7 +50,7 @@ def __init__(self, data=None, elements='auto'):
         self.elements = elements
 
         # Initialize model
-        self.model = LinearRegression()
+        self.model = Lasso(alpha=1e-9)
 
     def _preprocess_input(self, X):
         """
@@ -117,15 +117,25 @@ def _check_data(self, X):
             print("Error: Expected Data object to have non-empty attribute 'energies'" % self.__class__.__name__)
             raise SystemExit
 
-
     def _set_data(self, data):
         if data:
             self._check_data(data)
         self.data = data
 
-    def fit_transform(self, X, y=None):
+    def _parse_input(self, X, y):
+        if not isinstance(X, Data) and y is not None:
+            data = None
+            nuclear_charges = X
+        else:
+            data = self._preprocess_input(X)
+            nuclear_charges = data.nuclear_charges[data._indices]
+            y = data.energies[data._indices]
+
+        return data, nuclear_charges, y
+
+    def fit(self, X, y=None):
         """
-        Fit and transform the data with a linear model.
+        Fit the data with a linear model.
         Supports three different types of input.
         1) X is a list of nuclear charges and y is values to transform.
         2) X is an array of indices of which to transform.
@@ -135,27 +145,58 @@ def fit_transform(self, X, y=None):
         :type X: list
         :param y: Values to transform
         :type y: array or None
-        :return: Array of transformed values or Data object, depending on input
-        :rtype: array or Data object
         """
 
-        if not isinstance(X, Data) and y is not None:
-            data = None
-            nuclear_charges = X
-        else:
-            data = self._preprocess_input(X)
-            nuclear_charges = data.nuclear_charges[data._indices]
-            y = data.energies[data._indices]
+        self._fit(X, y)
+
+        return self
+
+    def _fit(self, X, y=None):
+        """
+        Does the work of the fit method, but returns variables
+        that the fit_transform method can reuse.
+        """
+
+        data, nuclear_charges, y = self._parse_input(X, y)
 
         if self.elements == 'auto':
             self.elements = get_unique(nuclear_charges)
         else:
             self._check_elements(nuclear_charges)
 
-
         features = self._featurizer(nuclear_charges)
 
-        delta_y = y - self.model.fit(features, y).predict(features)
+        self.model.fit(features, y)
+
+        return data, features, y
+
+    def fit_transform(self, X, y=None):
+        """
+        Fit and transform the data with a linear model.
+        Supports three different types of input.
+        1) X is a list of nuclear charges and y is values to transform.
+        2) X is an array of indices of which to transform.
+        3) X is a data object
+
+        :param X: List with nuclear charges or Data object.
+        :type X: list
+        :param y: Values to transform
+        :type y: array or None
+        :return: Array of transformed values or Data object, depending on input
+        :rtype: array or Data object
+        """
+
+        data, features, y = self._fit(X, y)
+
+        return self._transform(data, features, y)
+
+    def _transform(self, data, features, y):
+        """
+        Does the work of the transform method, but can be reused by
+        the fit_transform method.
+        """
+
+        delta_y = y - self.model.predict(features)
 
         if data:
             # Force copy
@@ -213,25 +254,10 @@ def transform(self, X, y=None):
         :rtype: array or Data object
         """
 
-        if not isinstance(X, Data) and y is not None:
-            data = None
-            nuclear_charges = X
-        else:
-            data = self._preprocess_input(X)
-            nuclear_charges = data.nuclear_charges[data._indices]
-            y = data.energies[data._indices]
+        data, nuclear_charges, y = self._parse_input(X, y)
 
         self._check_elements(nuclear_charges)
 
         features = self._featurizer(nuclear_charges)
 
-        delta_y = y - self.model.predict(features)
-
-        if data:
-            # Force copy
-            data.energies = data.energies.copy()
-            data.energies[data._indices] = delta_y
-            return data
-        else:
-            return delta_y
-
+        return self._transform(data, features, y)