-
Notifications
You must be signed in to change notification settings - Fork 132
/
Copy pathode.py
345 lines (279 loc) · 14.5 KB
/
ode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains ODE solvers: both sequential methods and the time-parallel multiple shooting methods required by multiple shooting layers [1].
The stateful design allows users to modify or tweak each tableau during training, ensuring compatibility with hybrid methods such as Hypersolvers [2].
[1]: Massaroli S., Poli M. et al "Differentiable Multiple Shooting Layers."
[2]: Poli M., Massaroli S. et al "Hypersolvers: Toward fast continuous-depth models." NeurIPS 2020
"""
from typing import Tuple
import torch
import torch.nn as nn
from torchdyn.numerics.solvers.templates import DiffEqSolver, MultipleShootingDiffeqSolver
from torchdyn.numerics.solvers._constants import construct_rk4, construct_dopri5, construct_tsit5
class SolverTemplate(nn.Module):
    """Minimal base class for single-step ODE solvers.

    Stores the method `order`, the step-size control constants `min_factor`,
    `max_factor` and `safety` (as 1-element tensors), and an optional
    `tableau` tuple `(c, a, bsol, berr)` that subclasses may fill in.
    """

    def __init__(self, order, min_factor: float = 0.2, max_factor: float = 10, safety: float = 0.9):
        super().__init__()
        self.order = order
        self.min_factor = torch.tensor([min_factor])
        self.max_factor = torch.tensor([max_factor])
        self.safety = torch.tensor([safety])
        self.tableau = None

    def sync_device_dtype(self, x, t_span):
        """Move solver tensors and `t_span` to the device of the state `x`.

        The tableau (if any) is cast to both the device and dtype of `x`;
        `t_span` and the step-size constants are only moved to `x`'s device.

        Args:
            x: state, either a tensor or a dict of tensors (the first value
                is used as the device/dtype prototype).
            t_span: tensor of time points.

        Returns:
            `(x, t_span)` with `x` unchanged and `t_span` on `x`'s device.

        Raises:
            ValueError: if `x` is an empty dict.
            NotImplementedError: if `x` is neither a tensor nor a dict.
        """
        if isinstance(x, dict):
            if not x:
                raise ValueError("Cannot infer device/dtype from an empty state dict")
            proto_arr = next(iter(x.values()))
        elif isinstance(x, torch.Tensor):
            proto_arr = x
        else:
            raise NotImplementedError(f"{type(x)} is not supported as the state variable")
        device = proto_arr.device
        if self.tableau is not None:
            c, a, bsol, berr = self.tableau
            # `.to(tensor)` matches both dtype and device of the prototype
            self.tableau = (
                c.to(proto_arr),
                [a_row.to(proto_arr) for a_row in a],
                bsol.to(proto_arr),
                berr.to(proto_arr),
            )
        t_span = t_span.to(device)
        self.safety = self.safety.to(device)
        self.min_factor = self.min_factor.to(device)
        self.max_factor = self.max_factor.to(device)
        return x, t_span

    def step(self, f, x, t, dt, k1=None, args=None):
        """Template hook overridden by concrete solvers; the base version does nothing (returns None)."""
        pass
class Euler(SolverTemplate):
    """Explicit (forward) Euler ODE stepper, order 1, fixed step."""

    def __init__(self, dtype=torch.float32):
        """Explicit Euler ODE stepper, order 1"""
        super().__init__(order=1)
        self.dtype = dtype
        self.stepping_class = 'fixed'

    def step(self, f, x, t, dt, k1=None, args=None):
        """Advance the state one step: x_sol = x + dt * f(t, x).

        Args:
            f: vector field called as `f(t, x)`.
            x: current state.
            t: current time.
            dt: step size.
            k1: optional precomputed `f(t, x)`; evaluated here when None.
            args: unused.

        Returns:
            `(None, x_sol, None)`; no stage reuse and no error estimate.
        """
        if k1 is None:
            k1 = f(t, x)
        x_sol = x + dt * k1
        return None, x_sol, None
class Midpoint(DiffEqSolver):
    """Explicit midpoint ODE stepper, order 2, fixed step."""

    def __init__(self, dtype=torch.float32):
        """Explicit Midpoint ODE stepper, order 2"""
        super().__init__(order=2)
        self.dtype = dtype
        self.stepping_class = 'fixed'

    def step(self, f, x, t, dt, k1=None, args=None):
        """Advance one step using the vector field evaluated at the midpoint.

        Args:
            f: vector field called as `f(t, x)`.
            x: current state.
            t: current time.
            dt: step size.
            k1: optional precomputed `f(t, x)`; evaluated here when None.
            args: unused.

        Returns:
            `(None, x_sol, None)`; no stage reuse and no error estimate.
        """
        if k1 is None:
            k1 = f(t, x)
        # half Euler step to the midpoint, then a full step with the midpoint slope
        x_mid = x + 0.5 * dt * k1
        x_sol = x + dt * f(t + 0.5 * dt, x_mid)
        return None, x_sol, None
class RungeKutta4(DiffEqSolver):
    """Explicit fourth-order Runge-Kutta ODE stepper (fixed step) using the tableau from `construct_rk4`."""

    def __init__(self, dtype=torch.float32):
        """Explicit Runge-Kutta ODE stepper, order 4"""
        super().__init__(order=4)
        self.dtype = dtype
        self.stepping_class = 'fixed'
        self.tableau = construct_rk4(self.dtype)

    def step(self, f, x, t, dt, k1=None, args=None):
        """Advance one RK4 step from (t, x).

        Args:
            f: vector field called as `f(t, x)`.
            x: current state.
            t: current time.
            dt: step size.
            k1: optional precomputed `f(t, x)`; evaluated here when None.
            args: unused.

        Returns:
            `(None, x_sol, None)`; no stage reuse and no error estimate.
        """
        c, a, bsol, _ = self.tableau
        if k1 is None:
            k1 = f(t, x)
        k2 = f(t + c[0] * dt, x + dt * (a[0] * k1))
        k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2))
        k4 = f(t + c[2] * dt, x + dt * (a[2][0] * k1 + a[2][1] * k2 + a[2][2] * k3))
        x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4)
        return None, x_sol, None
class AsynchronousLeapfrog(DiffEqSolver):
    """Explicit asynchronous leapfrog stepper for a state `xv = [x, v]` concatenated on the last axis."""

    def __init__(self, channel_index: int = -1, stepping_class: str = 'fixed', dtype=torch.float32):
        """Explicit Leapfrog symplectic ODE stepper.
        Can return local error estimates if adaptive stepping is required"""
        super().__init__(order=2)
        self.dtype = dtype
        # NOTE(review): stored but not used by `step`, which always splits the last axis — confirm intent.
        self.channel_index = channel_index
        self.stepping_class = stepping_class
        self.const = 1
        self.tableau = construct_rk4(self.dtype)
        # an additional overhead, necessary to preserve a certain degree of sanity
        # in the implementation and to avoid API bloating.
        self.x_shape = None

    def step(self, f, xv, t, dt, k1=None, args=None):
        """Advance one leapfrog step.

        The first half of the last axis of `xv` is treated as position `x`
        and the second half as `v`. `f` is evaluated once, at the half-step
        position.

        Args:
            f: vector field called as `f(t, x)` on the position half only.
            xv: concatenated state `[x, v]`.
            t: current time.
            dt: step size.
            k1: accepted for interface compatibility; not used by this method.
            args: unused.

        Returns:
            `(None, x_sol, xv_err)`; `xv_err` is only built when
            `stepping_class == 'adaptive'`, otherwise it is None.
        """
        half_state_dim = xv.shape[-1] // 2
        x, v = xv[..., :half_state_dim], xv[..., half_state_dim:]
        # No `k1 = f(t, x)` here: the leapfrog update never uses it, so computing it
        # would only cost an extra vector-field evaluation per step.
        x1 = x + 0.5 * dt * v
        vt1 = f(t + 0.5 * dt, x1)
        v1 = 2 * self.const * (vt1 - v) + v
        x2 = x1 + 0.5 * dt * v1
        x_sol = torch.cat([x2, v1], -1)
        if self.stepping_class == 'adaptive':
            xv_err = torch.cat([torch.zeros_like(x), v], -1)
        else:
            xv_err = None
        return None, x_sol, xv_err
class DormandPrince45(DiffEqSolver):
    """Dormand-Prince embedded Runge-Kutta stepper (order 5) with an error estimate for adaptive stepping.

    Coefficients `(c, a, bsol, berr)` come from `construct_dopri5`.
    """
    def __init__(self, dtype=torch.float32):
        super().__init__(order=5)
        self.dtype = dtype
        self.stepping_class = 'adaptive'
        self.tableau = construct_dopri5(self.dtype)
    def step(self, f, x, t, dt, k1=None, args=None) -> Tuple:
        """Advance one step of size `dt` from (t, x) and estimate the local error.

        Args:
            f: vector field called as `f(t, x)`.
            x: current state.
            t: current time.
            dt: step size.
            k1: optional precomputed `f(t, x)`; evaluated here when None.
            args: unused.

        Returns:
            `(k7, x_sol, err, (k1, ..., k7))`: the last stage (presumably reused by the
            caller as the next step's `k1` — confirm), the new state, the error estimate
            `dt * sum(berr[i] * k_i)`, and all seven stages.
        """
        c, a, bsol, berr = self.tableau
        if k1 is None: k1 = f(t, x)
        # NOTE(review): stage 2 uses `a[0]` directly while Tsitouras45 uses `a[0][0]`;
        # this relies on construct_dopri5's row layout — confirm.
        k2 = f(t + c[0] * dt, x + dt * a[0] * k1)
        k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2))
        k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3)
        k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4)
        k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
        k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
        # solution uses the first six stages; the error weights also include k7
        x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
        err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
        return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)
class Tsitouras45(DiffEqSolver):
    """Tsitouras embedded Runge-Kutta stepper (order 5) with an error estimate for adaptive stepping.

    Coefficients `(c, a, bsol, berr)` come from `construct_tsit5`.
    """

    def __init__(self, dtype=torch.float32):
        super().__init__(order=5)
        self.dtype = dtype
        self.stepping_class = 'adaptive'
        self.tableau = construct_tsit5(self.dtype)

    def step(self, f, x, t, dt, k1=None, args=None) -> Tuple:
        """Advance one step of size `dt` from (t, x) and estimate the local error.

        Args:
            f: vector field called as `f(t, x)`.
            x: current state.
            t: current time.
            dt: step size.
            k1: optional precomputed `f(t, x)`; evaluated here when None.
            args: unused.

        Returns:
            `(k7, x_sol, err, (k1, ..., k7))`: the last stage, the new state,
            the error estimate `dt * sum(berr[i] * k_i)`, and all seven stages.
        """
        c, a, bsol, berr = self.tableau
        if k1 is None:
            k1 = f(t, x)
        k2 = f(t + c[0] * dt, x + dt * a[0][0] * k1)
        k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2))
        k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3)
        k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4)
        k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
        k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
        # solution uses the first six stages; the error weights also include k7
        x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
        err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
        return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)
class ImplicitEuler(DiffEqSolver):
    """Implicit (backward) Euler stepper, order 1, fixed step.

    Each step solves x_sol = x + dt * f(t, x_sol) by minimizing the squared
    residual with L-BFGS, starting from the current state.
    """

    def __init__(self, dtype=torch.float32):
        super().__init__(order=1)
        self.dtype = dtype
        self.stepping_class = 'fixed'
        # optimizer class; a fresh instance is created for every step
        self.opt = torch.optim.LBFGS
        self.max_iters = 200

    @staticmethod
    def _residual(f, x, t, dt, x_sol):
        """Squared L2 norm of the backward-Euler defect x_sol - x - dt * f(t, x_sol)."""
        slope = f(t, x_sol)
        defect = x_sol - x - dt * slope
        return torch.sum(defect ** 2)

    def step(self, f, x, t, dt, k1=None, args=None):
        """Solve for the next state; `k1` and `args` are accepted but unused.

        Returns:
            `(None, x_sol, None)` where `x_sol` is the optimized `nn.Parameter`.
        """
        guess = nn.Parameter(data=x.clone())
        budget = self.max_iters
        optimizer = self.opt(
            [guess],
            lr=1,
            max_iter=budget,
            max_eval=10 * budget,
            tolerance_grad=1.e-12,
            tolerance_change=1.e-12,
            history_size=100,
            line_search_fn='strong_wolfe',
        )

        def closure():
            optimizer.zero_grad()
            loss = ImplicitEuler._residual(f, x, t, dt, guess)
            # gradient only w.r.t. the unknown; written directly into .grad for LBFGS
            (guess.grad,) = torch.autograd.grad(loss, guess, only_inputs=True, allow_unused=False)
            return loss

        optimizer.step(closure)
        return None, guess, None
class MSForward(MultipleShootingDiffeqSolver):
    """Multiple shooting solver using forward sensitivity analysis on the matching conditions of shooting parameters"""

    def __init__(self, coarse_method='euler', fine_method='rk4'):
        super().__init__(coarse_method, fine_method)

    def root_solve(self, *args, **kwargs):
        """Not implemented yet; always raises NotImplementedError.

        Accepts any arguments so that calls using the sibling solvers' signature
        `(odeint_func, f, x, t_span, B, fine_steps, maxiter)` reach the intended
        NotImplementedError instead of failing with a TypeError.
        """
        raise NotImplementedError("Waiting for `functorch` to be merged in the stable version of Pytorch; "
                                  "we need their vjp for efficient implementation of forward sensitivity. "
                                  "Refer to DiffEqML/diffeqml-research/multiple-shooting-layers for a manual implementation")
class MSZero(MultipleShootingDiffeqSolver):
    def __init__(self, coarse_method='euler', fine_method='rk4'):
        """Multiple shooting solver using Parareal updates (zero-order approximation of the Jacobian)
        Args:
            coarse_method (str, optional): method for the cheap coarse propagator. Defaults to 'euler'.
            fine_method (str, optional): method for the accurate fine propagator. Defaults to 'rk4'.
        """
        super().__init__(coarse_method, fine_method)

    # TODO (qol): extend to time-variant ODEs by using shifted_odeint
    def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter):
        """Refine the shooting parameters `B` with Parareal iterations.

        Args:
            odeint_func: integrator called as `odeint_func(f, x0, t, solver=...)`;
                element `[1][-1]` of its result is used as the final state.
            f: vector field.
            x: used only to set dtype/device of the sub-interval time grid.
            t_span: shooting time points; assumed uniformly spaced (only
                `t_span[1] - t_span[0]` is used).
            B: initial shooting parameters, one per entry of `t_span`.
            fine_steps: number of points in each sub-interval grid.
            maxiter: iteration budget; at most `maxiter + 1` corrections run.

        Returns:
            The updated shooting parameters.
        """
        dt, n_subinterv = t_span[1] - t_span[0], len(t_span)
        sub_t_span = torch.linspace(0, dt, fine_steps).to(x)
        i = 0
        # After n_subinterv - 1 corrections every parameter is fixed (B_out[:i] is
        # frozen each pass); further passes are no-ops and eventually index past B.
        while i <= maxiter and i < n_subinterv - 1:
            i += 1
            # entry k of these batches is the propagation of B[i - 1 + k]
            B_coarse = odeint_func(f, B[i-1:], sub_t_span, solver=self.coarse_method)[1][-1]
            B_fine = odeint_func(f, B[i-1:], sub_t_span, solver=self.fine_method)[1][-1]
            B_out = torch.zeros_like(B)
            B_out[:i] = B[:i]
            B_in = B[i-1]
            for m in range(i, n_subinterv):
                # Parareal update: G(new) - G(old) + F(old), with old = B[m - 1]
                B_in = odeint_func(f, B_in, sub_t_span, solver=self.coarse_method)[1][-1]
                B_in = B_in - B_coarse[m-i] + B_fine[m-i]
                B_out[m] = B_in
            B = B_out
        return B
class MSBackward(MultipleShootingDiffeqSolver):
    def __init__(self, coarse_method='euler', fine_method='rk4'):
        """Multiple shooting solver using discrete adjoints for the Jacobian
        Args:
            coarse_method (str, optional): coarse propagator method (not used by `root_solve`). Defaults to 'euler'.
            fine_method (str, optional): fine propagator method. Defaults to 'rk4'.
        """
        super().__init__(coarse_method, fine_method)

    def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter):
        """Refine the shooting parameters `B` with Newton-like sweeps using autograd products.

        Args:
            odeint_func: integrator called as `odeint_func(f, x0, t, solver=...)`;
                element `[1][-1]` of its result is used as the final state.
            f: vector field.
            x: used only to set dtype/device of the sub-interval time grid.
            t_span: shooting time points; assumed uniformly spaced (only
                `t_span[1] - t_span[0]` is used).
            B: initial shooting parameters; `requires_grad_(True)` is set on it in place.
            fine_steps: number of points in each sub-interval grid.
            maxiter: iteration budget; at most `maxiter + 1` sweeps run.

        Returns:
            The updated shooting parameters.
        """
        dt, n_subinterv = t_span[1] - t_span[0], len(t_span)
        sub_t_span = torch.linspace(0, dt, fine_steps).to(x)
        i = 0
        B = B.requires_grad_(True)
        # B_out[:i] is frozen each sweep, so after n_subinterv - 1 sweeps nothing changes;
        # going further would index past the end of B.
        while i <= maxiter and i < n_subinterv - 1:
            i += 1
            # entry k of B_fine is the fine propagation of B[i - 1 + k]
            B_fine = odeint_func(f, B[i-1:], sub_t_span, solver=self.fine_method)[1][-1]
            B_out = torch.zeros_like(B)
            B_out[:i] = B[:i]
            B_in = B[i-1]
            for m in range(i, n_subinterv):
                # propagation of B[m - 1] sits at offset m - i (not m - 1) in B_fine
                phi = B_fine[m-i]
                # instead of jvps here the full jacobian can be computed and the vector products
                # which involve `B_in` can be performed. Trading memory ++ for speed ++
                # NOTE(review): autograd.grad with grad_outputs yields a vector-Jacobian
                # product (J^T v), not J v — confirm this approximation is intended.
                J_blk = torch.autograd.grad(phi, B, B_in - B[m-1], retain_graph=True)[0][m-1]
                B_in = phi + J_blk
                B_out[m] = B_in
            del B  # manually free graph
            B = B_out
        return B
class ParallelImplicitEuler(MultipleShootingDiffeqSolver):
    """Time-parallel implicit Euler: all backward-Euler steps are solved jointly with L-BFGS."""

    def __init__(self, coarse_method='euler', fine_method='euler'):
        """Parallel Implicit Euler Method"""
        super().__init__(coarse_method, fine_method)
        # optimizer class; instantiated inside every root_solve call
        self.solver = torch.optim.LBFGS
        self.max_iters = 200

    def sync_device_dtype(self, x, t_span):
        """No-op: inputs are returned unchanged."""
        return x, t_span

    @staticmethod
    def _residual(f, x, B, t_span):
        """Sum of squared backward-Euler defects over every step of `t_span`.

        The vector field is evaluated at time 0 for all of `B[1:]`, so only
        autonomous fields are handled. The `[:, None, None]` broadcast
        presumes `B` is 3-D — confirm against callers.
        """
        steps = t_span[1:] - t_span[:-1]
        slopes = f(0., B[1:])
        # steps 1.. couple consecutive shooting parameters
        later = torch.sum((B[2:] - B[1:-1] - steps[1:, None, None] * slopes[1:]) ** 2)
        # step 0 is anchored at the initial condition x
        first = torch.sum((B[1] - x - steps[0] * slopes[0]) ** 2)
        return later + first

    # TODO (qol): extend to time-variant ODEs by model parallelization
    def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter):
        """Minimize the joint residual over all shooting parameters.

        `odeint_func`, `fine_steps` and `maxiter` are accepted for interface
        compatibility and not used; the iteration budget comes from `self.max_iters`.

        Returns:
            The optimized shooting parameters as an `nn.Parameter`.
        """
        unknowns = nn.Parameter(data=B.clone())
        budget = self.max_iters
        optimizer = self.solver(
            [unknowns],
            lr=1,
            max_iter=budget,
            max_eval=10 * budget,
            tolerance_grad=1.e-12,
            tolerance_change=1.e-12,
            history_size=100,
            line_search_fn='strong_wolfe',
        )

        def closure():
            optimizer.zero_grad()
            loss = ParallelImplicitEuler._residual(f, x, unknowns, t_span)
            (unknowns.grad,) = torch.autograd.grad(loss, unknowns, only_inputs=True, allow_unused=False)
            return loss

        optimizer.step(closure)
        return unknowns
# Name -> class registry used by `str_to_solver`; several aliases per solver.
SOLVER_DICT = {
    # explicit, fixed step
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RungeKutta4,
    'rk-4': RungeKutta4,
    'RungeKutta4': RungeKutta4,
    # explicit, adaptive (return an error estimate)
    'dopri5': DormandPrince45,
    'DormandPrince45': DormandPrince45,
    'DormandPrince5': DormandPrince45,
    'tsit5': Tsitouras45,
    'Tsitouras45': Tsitouras45,
    'Tsitouras5': Tsitouras45,
    # implicit
    'ieuler': ImplicitEuler,
    'implicit_euler': ImplicitEuler,
    # leapfrog
    'alf': AsynchronousLeapfrog,
    'AsynchronousLeapfrog': AsynchronousLeapfrog,
}

# Name -> class registry used by `str_to_ms_solver` (MSForward is not registered).
MS_SOLVER_DICT = {
    # Parareal
    'mszero': MSZero,
    'zero': MSZero,
    'parareal': MSZero,
    # discrete adjoint
    'msbackward': MSBackward,
    'backward': MSBackward,
    'discrete-adjoint': MSBackward,
    # parallel implicit Euler
    'ieuler': ParallelImplicitEuler,
    'parallel-implicit-euler': ParallelImplicitEuler,
}
def str_to_solver(solver_name, dtype=torch.float32):
    """Transforms string specifying desired solver into an instance of the Solver class.

    Args:
        solver_name: key of `SOLVER_DICT` (e.g. 'euler', 'rk4', 'dopri5').
        dtype: dtype forwarded to the solver constructor.

    Returns:
        A new solver instance.

    Raises:
        KeyError: if `solver_name` is not registered; the message lists valid names.
    """
    try:
        solver = SOLVER_DICT[solver_name]
    except KeyError:
        raise KeyError(f"Unknown solver '{solver_name}'. Available solvers: {sorted(SOLVER_DICT)}") from None
    # Pass dtype by keyword: AsynchronousLeapfrog's first positional parameter is
    # `channel_index`, so a positional call would put the dtype there.
    return solver(dtype=dtype)
def str_to_ms_solver(solver_name, dtype=torch.float32):
    """Returns an instance of the multiple shooting solver registered under `solver_name`.

    Args:
        solver_name: key of `MS_SOLVER_DICT` (e.g. 'mszero', 'msbackward').
        dtype: accepted for signature parity with `str_to_solver`; not used.

    Returns:
        A new multiple shooting solver built with its default arguments.

    Raises:
        KeyError: if `solver_name` is not registered; the message lists valid names.
    """
    try:
        solver = MS_SOLVER_DICT[solver_name]
    except KeyError:
        raise KeyError(f"Unknown multiple shooting solver '{solver_name}'. "
                       f"Available solvers: {sorted(MS_SOLVER_DICT)}") from None
    return solver()