Skip to content

Commit b0b2b73

Browse files
authored
support generalized force loss (deepmodeling#2690)
Support the loss for generalized forces. Tests and examples have been added. Generalized forces are given by ```math Q_j = \sum_{i=1}^n \mathbf F_i \cdot \frac {\partial \mathbf r_i} {\partial q_j},\quad j=1,\ldots, m. ``` The loss for generalized forces is given by ```math L_Q = \frac{1}{m} \sum_{j=1}^m (Q_j - Q_j^*)^2. ``` In the example, the generalized coordinates $q$ are the restraint coordinates in the enhanced sampling. This PR also improves documentation for other arguments in the loss. --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 2ca2f9c commit b0b2b73

File tree

8 files changed

+449
-13
lines changed

8 files changed

+449
-13
lines changed

deepmd/loss/ener.py

+103-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,42 @@ class EnerStdLoss(Loss):
2727
2828
Parameters
2929
----------
30+
starter_learning_rate : float
31+
The learning rate at the start of the training.
32+
start_pref_e : float
33+
The prefactor of energy loss at the start of the training.
34+
limit_pref_e : float
35+
The prefactor of energy loss at the end of the training.
36+
start_pref_f : float
37+
The prefactor of force loss at the start of the training.
38+
limit_pref_f : float
39+
The prefactor of force loss at the end of the training.
40+
start_pref_v : float
41+
The prefactor of virial loss at the start of the training.
42+
limit_pref_v : float
43+
The prefactor of virial loss at the end of the training.
44+
start_pref_ae : float
45+
The prefactor of atomic energy loss at the start of the training.
46+
limit_pref_ae : float
47+
The prefactor of atomic energy loss at the end of the training.
48+
start_pref_pf : float
49+
The prefactor of atomic prefactor force loss at the start of the training.
50+
limit_pref_pf : float
51+
The prefactor of atomic prefactor force loss at the end of the training.
52+
relative_f : float
53+
If provided, relative force error will be used in the loss. The difference
54+
of force will be normalized by the magnitude of the force in the label with
55+
a shift given by relative_f
3056
enable_atom_ener_coeff : bool
3157
if true, the energy will be computed as \sum_i c_i E_i
58+
start_pref_gf : float
59+
The prefactor of generalized force loss at the start of the training.
60+
limit_pref_gf : float
61+
The prefactor of generalized force loss at the end of the training.
62+
numb_generalized_coord : int
63+
The dimension of generalized coordinates.
64+
**kwargs
65+
Other keyword arguments.
3266
"""
3367

3468
def __init__(
@@ -46,6 +80,9 @@ def __init__(
4680
limit_pref_pf: float = 0.0,
4781
relative_f: Optional[float] = None,
4882
enable_atom_ener_coeff: bool = False,
83+
start_pref_gf: float = 0.0,
84+
limit_pref_gf: float = 0.0,
85+
numb_generalized_coord: int = 0,
4986
**kwargs,
5087
) -> None:
5188
self.starter_learning_rate = starter_learning_rate
@@ -61,11 +98,19 @@ def __init__(
6198
self.limit_pref_pf = limit_pref_pf
6299
self.relative_f = relative_f
63100
self.enable_atom_ener_coeff = enable_atom_ener_coeff
101+
self.start_pref_gf = start_pref_gf
102+
self.limit_pref_gf = limit_pref_gf
103+
self.numb_generalized_coord = numb_generalized_coord
64104
self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0
65105
self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0
66106
self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0
67107
self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0
68108
self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0
109+
self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0
110+
if self.has_gf and self.numb_generalized_coord < 1:
111+
raise RuntimeError(
112+
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
113+
)
69114
# data required
70115
add_data_requirement("energy", 1, atomic=False, must=False, high_prec=True)
71116
add_data_requirement("force", 3, atomic=True, must=False, high_prec=False)
@@ -74,6 +119,15 @@ def __init__(
74119
add_data_requirement(
75120
"atom_pref", 1, atomic=True, must=False, high_prec=False, repeat=3
76121
)
122+
# drdq: the partial derivative of atomic coordinates w.r.t. generalized coordinates
123+
# TODO: could numb_generalized_coord be decided from the training data?
124+
add_data_requirement(
125+
"drdq",
126+
self.numb_generalized_coord * 3,
127+
atomic=True,
128+
must=False,
129+
high_prec=False,
130+
)
77131
if self.enable_atom_ener_coeff:
78132
add_data_requirement(
79133
"atom_ener_coeff",
@@ -99,6 +153,9 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
99153
find_virial = label_dict["find_virial"]
100154
find_atom_ener = label_dict["find_atom_ener"]
101155
find_atom_pref = label_dict["find_atom_pref"]
156+
if self.has_gf:
157+
drdq = label_dict["drdq"]
158+
find_drdq = label_dict["find_drdq"]
102159

103160
if self.enable_atom_ener_coeff:
104161
# when ener_coeff (\nu) is defined, the energy is defined as
@@ -117,7 +174,7 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
117174
tf.square(energy - energy_hat), name="l2_" + suffix
118175
)
119176

120-
if self.has_f or self.has_pf or self.relative_f:
177+
if self.has_f or self.has_pf or self.relative_f or self.has_gf:
121178
force_reshape = tf.reshape(force, [-1])
122179
force_hat_reshape = tf.reshape(force_hat, [-1])
123180
diff_f = force_hat_reshape - force_reshape
@@ -139,6 +196,22 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
139196
name="l2_pref_force_" + suffix,
140197
)
141198

199+
if self.has_gf:
200+
drdq = label_dict["drdq"]
201+
force_reshape_nframes = tf.reshape(force, [-1, natoms[0] * 3])
202+
force_hat_reshape_nframes = tf.reshape(force_hat, [-1, natoms[0] * 3])
203+
drdq_reshape = tf.reshape(
204+
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
205+
)
206+
gen_force_hat = tf.einsum(
207+
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
208+
)
209+
gen_force = tf.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
210+
diff_gen_force = gen_force_hat - gen_force
211+
l2_gen_force_loss = tf.reduce_mean(
212+
tf.square(diff_gen_force), name="l2_gen_force_" + suffix
213+
)
214+
142215
if self.has_v:
143216
virial_reshape = tf.reshape(virial, [-1])
144217
virial_hat_reshape = tf.reshape(virial_hat, [-1])
@@ -202,6 +275,16 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
202275
/ self.starter_learning_rate
203276
)
204277
)
278+
if self.has_gf:
279+
pref_gf = global_cvt_2_tf_float(
280+
find_drdq
281+
* (
282+
self.limit_pref_gf
283+
+ (self.start_pref_gf - self.limit_pref_gf)
284+
* learning_rate
285+
/ self.starter_learning_rate
286+
)
287+
)
205288

206289
l2_loss = 0
207290
more_loss = {}
@@ -220,6 +303,9 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
220303
if self.has_pf:
221304
l2_loss += global_cvt_2_ener_float(pref_pf * l2_pref_force_loss)
222305
more_loss["l2_pref_force_loss"] = l2_pref_force_loss
306+
if self.has_gf:
307+
l2_loss += global_cvt_2_ener_float(pref_gf * l2_gen_force_loss)
308+
more_loss["l2_gen_force_loss"] = l2_gen_force_loss
223309

224310
# only used when tensorboard was set as true
225311
self.l2_loss_summary = tf.summary.scalar("l2_loss_" + suffix, tf.sqrt(l2_loss))
@@ -238,6 +324,18 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
238324
"l2_virial_loss_" + suffix,
239325
tf.sqrt(l2_virial_loss) / global_cvt_2_tf_float(natoms[0]),
240326
)
327+
if self.has_ae:
328+
self.l2_loss_atom_ener_summary = tf.summary.scalar(
329+
"l2_atom_ener_loss_" + suffix, tf.sqrt(l2_atom_ener_loss)
330+
)
331+
if self.has_pf:
332+
self.l2_loss_pref_force_summary = tf.summary.scalar(
333+
"l2_pref_force_loss_" + suffix, tf.sqrt(l2_pref_force_loss)
334+
)
335+
if self.has_gf:
336+
self.l2_loss_gf_summary = tf.summary.scalar(
337+
"l2_gen_force_loss_" + suffix, tf.sqrt(l2_gen_force_loss)
338+
)
241339

242340
self.l2_l = l2_loss
243341
self.l2_more = more_loss
@@ -252,8 +350,9 @@ def eval(self, sess, feed_dict, natoms):
252350
self.l2_more["l2_virial_loss"] if self.has_v else placeholder,
253351
self.l2_more["l2_atom_ener_loss"] if self.has_ae else placeholder,
254352
self.l2_more["l2_pref_force_loss"] if self.has_pf else placeholder,
353+
self.l2_more["l2_gen_force_loss"] if self.has_gf else placeholder,
255354
]
256-
error, error_e, error_f, error_v, error_ae, error_pf = run_sess(
355+
error, error_e, error_f, error_v, error_ae, error_pf, error_gf = run_sess(
257356
sess, run_data, feed_dict=feed_dict
258357
)
259358
results = {"natoms": natoms[0], "rmse": np.sqrt(error)}
@@ -267,6 +366,8 @@ def eval(self, sess, feed_dict, natoms):
267366
results["rmse_v"] = np.sqrt(error_v) / natoms[0]
268367
if self.has_pf:
269368
results["rmse_pf"] = np.sqrt(error_pf)
369+
if self.has_gf:
370+
results["rmse_gf"] = np.sqrt(error_gf)
270371
return results
271372

272373

deepmd/utils/argcheck.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ def pairwise_dprc() -> Argument:
918918

919919
# --- Learning rate configurations: --- #
920920
def learning_rate_exp():
921-
doc_start_lr = "The learning rate the start of the training."
921+
doc_start_lr = "The learning rate at the start of the training."
922922
doc_stop_lr = "The desired learning rate at the end of the training."
923923
doc_decay_steps = (
924924
"The learning rate is decaying every this number of training steps."
@@ -977,25 +977,34 @@ def learning_rate_dict_args():
977977

978978

979979
# --- Loss configurations: --- #
980-
def start_pref(item):
981-
return f"The prefactor of {item} loss at the start of the training. Should be larger than or equal to 0. If set to none-zero value, the {item} label should be provided by file {item}.npy in each data system. If both start_pref_{item} and limit_pref_{item} are set to 0, then the {item} will be ignored."
980+
def start_pref(item, label=None, abbr=None):
981+
if label is None:
982+
label = item
983+
if abbr is None:
984+
abbr = item
985+
return f"The prefactor of {item} loss at the start of the training. Should be larger than or equal to 0. If set to none-zero value, the {label} label should be provided by file {label}.npy in each data system. If both start_pref_{abbr} and limit_pref_{abbr} are set to 0, then the {item} will be ignored."
982986

983987

984988
def limit_pref(item):
985989
return f"The prefactor of {item} loss at the limit of the training, Should be larger than or equal to 0. i.e. the training step goes to infinity."
986990

987991

988992
def loss_ener():
989-
doc_start_pref_e = start_pref("energy")
993+
doc_start_pref_e = start_pref("energy", abbr="e")
990994
doc_limit_pref_e = limit_pref("energy")
991-
doc_start_pref_f = start_pref("force")
995+
doc_start_pref_f = start_pref("force", abbr="f")
992996
doc_limit_pref_f = limit_pref("force")
993-
doc_start_pref_v = start_pref("virial")
997+
doc_start_pref_v = start_pref("virial", abbr="v")
994998
doc_limit_pref_v = limit_pref("virial")
995-
doc_start_pref_ae = start_pref("atom_ener")
996-
doc_limit_pref_ae = limit_pref("atom_ener")
997-
doc_start_pref_pf = start_pref("atom_pref")
998-
doc_limit_pref_pf = limit_pref("atom_pref")
999+
doc_start_pref_ae = start_pref("atomic energy", label="atom_ener", abbr="ae")
1000+
doc_limit_pref_ae = limit_pref("atomic energy")
1001+
doc_start_pref_pf = start_pref(
1002+
"atomic prefactor force", label="atom_pref", abbr="pf"
1003+
)
1004+
doc_limit_pref_pf = limit_pref("atomic prefactor force")
1005+
doc_start_pref_gf = start_pref("generalized force", label="drdq", abbr="gf")
1006+
doc_limit_pref_gf = limit_pref("generalized force")
1007+
doc_numb_generalized_coord = "The dimension of generalized coordinates. Required when generalized force loss is used."
9991008
doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label."
10001009
doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1."
10011010
return [
@@ -1077,6 +1086,27 @@ def loss_ener():
10771086
default=False,
10781087
doc=doc_enable_atom_ener_coeff,
10791088
),
1089+
Argument(
1090+
"start_pref_gf",
1091+
float,
1092+
optional=True,
1093+
default=0.0,
1094+
doc=doc_start_pref_gf,
1095+
),
1096+
Argument(
1097+
"limit_pref_gf",
1098+
float,
1099+
optional=True,
1100+
default=0.0,
1101+
doc=doc_limit_pref_gf,
1102+
),
1103+
Argument(
1104+
"numb_generalized_coord",
1105+
int,
1106+
optional=True,
1107+
default=0,
1108+
doc=doc_numb_generalized_coord,
1109+
),
10801110
]
10811111

10821112

doc/data/system.md

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dipole | Frame dipole | dipole.raw | A
3333
atomic_dipole | Atomic dipole | atomic_dipole.raw | Any | Nframes \* Natoms \* 3 |
3434
polarizability | Frame polarizability | polarizability.raw | Any | Nframes \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ`
3535
atomic_polarizability | Atomic polarizability | atomic_polarizability.raw| Any | Nframes \* Natoms \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ`
36+
drdq | Partial derivative of atomic coordinates with respect to generalized coordinates | drdq.raw | 1 | Nframes \* Natoms \* 3 \* Ngen_coords |
3637

3738
In general, we always use the following convention of units:
3839

examples/dprc/data/set.000/drdq.npy

6.57 KB
Binary file not shown.

0 commit comments

Comments
 (0)