Skip to content

Commit 5c72d08

Browse files
committed
Duck typing and consistent return types (Vector{typeof(y0)}) for all solvers; see SciML#7. Tests adjusted accordingly.
1 parent c20b016 commit 5c72d08

File tree

2 files changed

+60
-54
lines changed

2 files changed

+60
-54
lines changed

src/ODE.jl

+55-50
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ export ode23, ode4, ode45, ode4s, ode4ms
4343
# Initialize variables.
4444
# Adapted from Cleve Moler's textbook
4545
# http://www.mathworks.com/moler/ncm/ode23tx.m
46-
function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T}; reltol = 1.e-5, abstol = 1.e-8)
47-
46+
function ode23(F, tspan, y0; reltol = 1.e-5, abstol = 1.e-8)
4847
if reltol == 0
4948
warn("setting reltol = 0 gives a step size of zero")
5049
end
@@ -58,7 +57,8 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T}; rel
5857
y = y0
5958

6059
tout = t
61-
yout = y.'
60+
yout = Array(typeof(y0),1)
61+
yout[1] = y
6262

6363
tlen = length(t)
6464

@@ -101,7 +101,7 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T}; rel
101101
t = tnew
102102
y = ynew
103103
tout = [tout; t]
104-
yout = [yout; y.']
104+
push!(yout, y)
105105
s1 = s4 # Reuse final function value to start new step
106106
end
107107

@@ -118,7 +118,7 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T}; rel
118118

119119
end # while (t != tfinal)
120120

121-
return (tout, yout)
121+
return tout, yout
122122

123123
end # ode23
124124

@@ -180,9 +180,7 @@ end # ode23
180180
181181
# created : 06 October 1999
182182
# modified: 17 January 2001
183-
184-
function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a, b4, b5; reltol = 1.0e-5, abstol = 1.0e-8)
185-
183+
function oderkf(F, tspan, x0, a, b4, b5; reltol = 1.0e-5, abstol = 1.0e-8)
186184
# see p.91 in the Ascher & Petzold reference for more infomation.
187185
pow = 1/5
188186

@@ -196,25 +194,26 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
196194
h = (tfinal - t)/100 # initial guess at a step size
197195
x = x0
198196
tout = t # first output time
199-
xout = x.' # first output solution
197+
xout = Array(typeof(x0), 1)
198+
xout[1] = x # first output solution
200199

201-
k = zeros(eltype(x), length(c), length(x))
202-
k[1,:] = F(t,x) # first stage
200+
k = Array(typeof(x0), length(c))
201+
k[1] = F(t,x) # first stage
203202

204203
while t < tfinal && h >= hmin
205204
if t + h > tfinal
206205
h = tfinal - t
207206
end
208207

209208
for j = 2:length(c)
210-
k[j,:] = F(t + h.*c[j], x + h.*(a[j,1:j-1]*k[1:j-1,:]).')
209+
k[j] = F(t + h.*c[j], x + h.*(a[j,1:j-1]*k[1:j-1])[1])
211210
end
212211

213212
# compute the 4th order estimate
214-
x4 = x + h.*(b4*k).'
213+
x4 = x + h.*(b4*k)[1]
215214

216215
# compute the 5th order estimate
217-
x5 = x + h.*(b5*k).'
216+
x5 = x + h.*(b5*k)[1]
218217

219218
# estimate the local truncation error
220219
gamma1 = x5 - x4
@@ -228,7 +227,7 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
228227
t = t + h
229228
x = x5 # <-- using the higher order estimate is called 'local extrapolation'
230229
tout = [tout; t]
231-
xout = [xout; x.']
230+
push!(xout, x)
232231

233232
# Compute the slopes by computing the k[:,j+1]'th column based on the previous k[:,1:j] columns
234233
# notes: k needs to end up as an Nxs, a is 7x6, which is s by (s-1),
@@ -238,9 +237,9 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
238237
# This is part of the Dormand-Prince pair caveat.
239238
# k[:,7] has already been computed, so use it instead of recomputing it
240239
# again as k[:,1] during the next step.
241-
k[1,:] = k[end,:]
240+
k[1] = k[end]
242241
else
243-
k[1,:] = F(t,x) # first stage
242+
k[1] = F(t,x) # first stage
244243
end
245244
end
246245

@@ -252,7 +251,7 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
252251
println("Step size grew too small. t=", t, ", h=", h, ", x=", x)
253252
end
254253

255-
return (tout, xout)
254+
return tout, xout
256255
end
257256

258257
# Both the Dormand-Prince and Fehlberg 4(5) coefficients are from a tableau in
@@ -316,59 +315,65 @@ const ode45 = ode45_dp
316315
# ODEFUN(T,X) must return a column vector corresponding to f(t,x). Each
317316
# row in the solution array X corresponds to a time returned in the
318317
# column vector T.
319-
function ode4{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T})
318+
function ode4(F, tspan, x0)
320319
h = diff(tspan)
321-
x = Array(T, length(tspan), length(x0))
322-
x[1,:] = x0
320+
x = Array(typeof(x0), length(tspan))
321+
x[1] = x0
323322

324-
midxdot = Array(T, 4, length(x0))
323+
midxdot = Array(typeof(x0), 4)
325324
for i = 1:length(tspan)-1
326325
# Compute midstep derivatives
327-
midxdot[1,:] = F(tspan[i], x[i,:]')
328-
midxdot[2,:] = F(tspan[i]+h[i]./2, x[i,:]' + midxdot[1,:]'.*h[i]./2)
329-
midxdot[3,:] = F(tspan[i]+h[i]./2, x[i,:]' + midxdot[2,:]'.*h[i]./2)
330-
midxdot[4,:] = F(tspan[i]+h[i], x[i,:]' + midxdot[3,:]'.*h[i])
326+
midxdot[1] = F(tspan[i], x[i])
327+
midxdot[2] = F(tspan[i]+h[i]./2, x[i] + midxdot[1].*h[i]./2)
328+
midxdot[3] = F(tspan[i]+h[i]./2, x[i] + midxdot[2].*h[i]./2)
329+
midxdot[4] = F(tspan[i]+h[i], x[i] + midxdot[3].*h[i])
331330

332331
# Integrate
333-
x[i+1,:] = x[i,:] + 1./6.*h[i].*[1 2 2 1]*midxdot
332+
x[i+1] = x[i] + 1./6.*(h[i].*[1 2 2 1]*midxdot)[1]
334333
end
335-
return (tspan, x)
334+
return tspan, x
336335
end
337336

338337
#ODEROSENBROCK Solve stiff differential equations, Rosenbrock method
339338
# with provided coefficients.
340-
function oderosenbrock{T}(F::Function, G::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c)
339+
function oderosenbrock(F, G, tspan, x0, gamma, a, b, c)
341340
h = diff(tspan)
342-
x = Array(T, length(tspan), length(x0))
343-
x[1,:] = x0
341+
x = Array(typeof(x0), length(tspan))
342+
x[1] = x0
344343

345344
solstep = 1
346345
while tspan[solstep] < maximum(tspan)
347346
ts = tspan[solstep]
348347
hs = h[solstep]
349-
xs = reshape(x[solstep,:], size(x0))
348+
xs = x[solstep]
350349
dFdx = G(ts, xs)
351-
jac = eye(size(dFdx,1))./gamma./hs-dFdx
350+
# FIXME
351+
if size(dFdx,1) == 1
352+
jac = 1/gamma/hs - dFdx[1]
353+
else
354+
jac = eye(dFdx)./gamma./hs - dFdx
355+
end
352356

353-
g = zeros(size(a,1), length(x0))
354-
g[1,:] = jac \ F(ts + b[1].*hs, xs)
357+
g = Array(typeof(x0), size(a,1))
358+
g[1] = (jac \ F(ts + b[1].*hs, xs))
355359
for i = 2:size(a,1)
356-
g[i,:] = jac \ (F(ts + b[i].*hs, xs + (a[i,1:i-1]*g[1:i-1,:]).') + (c[i,1:i-1]*g[1:i-1,:]).'./hs)
360+
g[i] = (jac \ (F(ts + b[i].*hs, xs + (a[i,1:i-1]*g[1:i-1])[1]) + (c[i,1:i-1]*g[1:i-1])[1]./hs))
357361
end
358-
359-
x[solstep+1,:] = x[solstep,:] + b*g
362+
x[solstep+1] = x[solstep] + (b*g)[1]
360363
solstep += 1
361364
end
362-
return (tspan, x)
365+
return tspan, x
363366
end
364367

365-
function oderosenbrock{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c)
368+
function oderosenbrock(F, tspan, x0, gamma, a, b, c)
366369
# Crude forward finite differences estimator as fallback
367-
function jacobian(F::Function, t::Number, x::AbstractVector)
370+
# FIXME: This doesn't really work if x is anything but a Vector or a scalar
371+
function jacobian(F, t, x)
368372
ftx = F(t, x)
369-
dFdx = zeros(length(x), length(x))
370-
for j = 1:length(x)
371-
dx = zeros(size(x))
373+
lx = max(length(x),1)
374+
dFdx = zeros(eltype(x), lx, lx)
375+
for j = 1:lx
376+
dx = zeros(eltype(x), lx)
372377
# The 100 below is heuristic
373378
dx[j] = (x[j]+(x[j]==0))./100
374379
dFdx[:,j] = (F(t,x+dx)-ftx)./dx[j]
@@ -412,10 +417,10 @@ ode4s_s(F, G, tspan, x0) = oderosenbrock(F, G, tspan, x0, s4_coefficients...)
412417
const ode4s = ode4s_s
413418

414419
# ODE_MS Fixed-step, fixed-order multi-step numerical method with Adams-Bashforth-Moulton coefficients
415-
function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, order::Integer)
420+
function ode_ms(F, tspan, x0, order::Integer)
416421
h = diff(tspan)
417-
x = zeros(T, length(tspan), length(x0))
418-
x[1,:] = x0
422+
x = Array(typeof(x0), length(tspan))
423+
x[1] = x0
419424

420425
if 1 <= order <= 4
421426
b = [ 1 0 0 0
@@ -438,10 +443,10 @@ function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, or
438443
for i = 1:length(tspan)-1
439444
# Need to run the first several steps at reduced order
440445
steporder = min(i, order)
441-
xdot[i,:] = F(tspan[i], x[i,:]')
442-
x[i+1,:] = x[i,:] + b[steporder,1:steporder]*xdot[i-(steporder-1):i,:].*h[i]
446+
xdot[i] = F(tspan[i], x[i])
447+
x[i+1] = x[i] + (b[steporder,1:steporder]*xdot[i-(steporder-1):i])[1].*h[i]
443448
end
444-
return (tspan, x)
449+
return tspan, x
445450
end
446451

447452
# Use order 4 by default

test/tests.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@ for solver in solvers
1919
# dy
2020
# -- = 6 ==> y = 6t
2121
# dt
22-
t,y=solver((t,y)->6, [0:.1:1], [0.])
22+
t,y=solver((t,y)->6, [0:.1:1], 0.)
2323
@test maximum(abs(y-6t)) < tol
2424

2525
# dy
2626
# -- = 2t ==> y = t.^2
2727
# dt
28-
t,y=solver((t,y)->2t, [0:.001:1], [0.])
28+
t,y=solver((t,y)->2t, [0:.001:1], 0.)
2929
@test maximum(abs(y-t.^2)) < tol
3030

3131
# dy
3232
# -- = y ==> y = y0*e.^t
3333
# dt
34-
t,y=solver((t,y)->y, [0:.001:1], [1.])
34+
t,y=solver((t,y)->y, [0:.001:1], 1.)
3535
@test maximum(abs(y-e.^t)) < tol
3636

3737
# dv dw
@@ -40,7 +40,8 @@ for solver in solvers
4040
#
4141
# y = [v, w]
4242
t,y=solver((t,y)->[-y[2], y[1]], [0:.001:2*pi], [1., 2.])
43-
@test maximum(abs(y-[cos(t)-2*sin(t) 2*cos(t)+sin(t)])) < tol
43+
ys = hcat(y...).' # convert Vector{Vector{Float}} to Matrix{Float}
44+
@test maximum(abs(ys-[cos(t)-2*sin(t) 2*cos(t)+sin(t)])) < tol
4445
end
4546

4647
println("All looks OK")

0 commit comments

Comments
 (0)