This repository was archived by the owner on Dec 8, 2024. It is now read-only.

Commit 98d8299

Silvia authored and committed
Tolerance fix (#104)

* Corrected small bug in predict function
* Started updating so that the model can be trained after it has been reloaded
* Minor modifications
* Updated model so one can predict from xyz, and disabled shuffling in training because it leads to a problem with predictions
* Fix for the shuffling problem
* Added some tests to make sure the predictions work
* Fixed a tensorboard problem
* Saving the model no longer causes an error if the directory already exists
* Fixed a bug that made a test fail
* Modified the name of a parameter
* Made modifications to make the symmetry functions more numerically stable
* Added a hack that makes ARMP work with fortran ACSF when there are padded representations. Currently works *ONLY* when there is one molecule for the whole data set.
* Corrected bug in score function for padded molecules
* Changes that make the model work quickly even when there is padding
* Fixed discrepancies between fortran and TF ACSF
* Corrected bug in setting of ACSF parameters
* Attempt at fixing issue #10
* Another attempt at fixing #10
* Removed a pointless line
* Set-up
* Added the graceful killer
* Modifications which prevent installation from breaking on BC4
* Modification to add neural networks to qmlearn
* Fix for issue #8
* Random comment
* Started including the atomic model
* Made the atomic neural network work
* Fixed a bug with the indices
* Now training and predictions don't use the default graph, to avoid problems
* Uncommented examples
* Removed unique_elements in data class. This can be stored in the NN class, but I might reverse the change later
* Made tensorflow an optional dependency. The reason for this approach is that pip would just auto-install tensorflow, and you might want the GPU version or your own compiled one.
* Made is_numeric non-private and removed legacy code
* Added 1d array util function
* Removed QML check and moved functions from utils to tf_utils
* Support for linear models (no hidden layers)
* Fixed import bug in tf_utils
* Added text to explain that you are scoring on the training set
* Restructure, but elements are still not working. Sorted elements
* Moved documentation from init to class
* Constant features will now be removed at fit/predict time
* Moved get_batch_size back into utils, since it doesn't depend on tf
* Made the NeuralNetwork class compliant with sklearn. There cannot be any transforms of the input data
* Fixed tests that didn't pass
* Fixed mistake in checks of set_classes() in ARMP
* Started fixing ARMP bugs for QM7
* Fixed bug in padding and added examples that give low errors
* Attempted fix to make representations single precision
* Hot fix for AtomScaler
* Minor bug fixes
* More bug fixes to make sure tests run
* Fixed some tests that had failures
* Reverted the fchl tests to original
* Fixed path in ACSF test
* Re-added changes to tests
* Modifications after code review
* Version with the ACSF basis functions starting at 0.8 Å
* Updated ACSF representations so that the minimum distance at which to start the binning can be set by the user
* Modified the name of the new parameter (minimum distance of the binning in ACSF)
* Added a function to the AtomScaler that enables reverting back
* Relaxed tolerance in tests
1 parent 50cc6a7 commit 98d8299

File tree

2 files changed: +45 −6 lines changed


qml/qmlearn/preprocessing.py (+39)
```diff
@@ -206,6 +206,21 @@ def _transform(self, data, features, y):
         else:
             return delta_y
 
+    def _revert_transform(self, data, features, y):
+        """
+        Reverts the work of the transform method.
+        """
+
+        full_y = y + self.model.predict(features)
+
+        if data:
+            # Force copy
+            data.energies = data.energies.copy()
+            data.energies[data._indices] = full_y
+            return data
+        else:
+            return full_y
+
     def _check_elements(self, nuclear_charges):
         """
         Check that the elements in the given nuclear_charges was
@@ -261,3 +276,27 @@ def transform(self, X, y=None):
         features = self._featurizer(nuclear_charges)
 
         return self._transform(data, features, y)
+
+    def revert_transform(self, X, y=None):
+        """
+        Transforms data back to what it originally would have been if it hadn't been transformed with the fitted linear
+        model. Supports three different types of input.
+        1) X is a list of nuclear charges and y is values to transform.
+        2) X is an array of indices of which to transform.
+        3) X is a data object
+
+        :param X: List with nuclear charges or Data object.
+        :type X: list
+        :param y: Values to revert to before transform
+        :type y: array or None
+        :return: Array of untransformed values or Data object, depending on input
+        :rtype: array or Data object
+        """
+
+        data, nuclear_charges, y = self._parse_input(X, y)
+
+        self._check_elements(nuclear_charges)
+
+        features = self._featurizer(nuclear_charges)
+
+        return self._revert_transform(data, features, y)
```
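The new `_revert_transform` is the inverse of `_transform`: the transform subtracts the fitted linear model's prediction from the target values, and the revert adds it back. A minimal round-trip sketch of that idea, using scikit-learn's `LinearRegression` as a stand-in for the fitted baseline model (the feature matrix and energy values below are made up for illustration):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical features (e.g. per-element atom counts) and molecular energies
features = np.array([[2.0, 1.0], [1.0, 2.0], [3.0, 0.0]])
y = np.array([-76.0, -77.5, -2.3])

# Stand-in for the fitted baseline model stored on the scaler
model = LinearRegression().fit(features, y)

# transform: subtract the linear baseline, leaving residual energies
delta_y = y - model.predict(features)

# revert_transform: add the baseline back to recover the original values
full_y = delta_y + model.predict(features)

assert np.allclose(full_y, y)
```

Since the revert simply adds back the same model prediction, the round trip is exact up to floating-point error.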

test/test_armp.py (+6 −6)
```diff
@@ -228,7 +228,7 @@ def test_predict_fromxyz():
     pred1 = estimator.predict(idx)
     pred2 = estimator.predict_from_xyz(xyz, zs)
 
-    assert np.all(np.isclose(pred1, pred2, rtol=1.e-6))
+    assert np.all(np.isclose(pred1, pred2, rtol=1.e-5))
 
     estimator.save_nn(save_dir="temp")
 
@@ -243,11 +243,11 @@ def test_predict_fromxyz():
     pred3 = new_estimator.predict(idx)
     pred4 = new_estimator.predict_from_xyz(xyz, zs)
 
-    assert np.all(np.isclose(pred3, pred4, rtol=1.e-6))
-    assert np.all(np.isclose(pred1, pred3, rtol=1.e-6))
-
     shutil.rmtree("temp")
 
+    assert np.all(np.isclose(pred3, pred4, rtol=1.e-5))
+    assert np.all(np.isclose(pred1, pred3, rtol=1.e-5))
+
 def test_retraining():
     xyz = np.array([[[0, 1, 0], [0, 1, 1], [1, 0, 1]],
                     [[1, 2, 2], [3, 1, 2], [1, 3, 4]],
@@ -291,8 +291,8 @@ def test_retraining():
 
     pred4 = new_estimator.predict(idx)
 
-    assert np.all(np.isclose(pred1, pred3, rtol=1.e-6))
-    assert np.all(np.isclose(pred2, pred4, rtol=1.e-6))
+    assert np.all(np.isclose(pred1, pred3, rtol=1.e-5))
+    assert np.all(np.isclose(pred2, pred4, rtol=1.e-5))
 
     shutil.rmtree("temp")
```

0 commit comments