Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: create and solve initialization system in linearization_function #2676

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.12"
SymbolicUtils = "1.0"
SymbolicUtils = "<1.6"
Symbolics = "5.26"
URIs = "1"
UnPack = "0.1, 1.0"
Expand Down
130 changes: 114 additions & 16 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,14 @@
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
return let _fn = build_explicit_observed_function(sys, sym)
fn(u, p, t) = _fn(u, p, t)
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
fn
end
end

function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
return merge(
Dict(eq.lhs => eq.rhs for eq in observed(sys)),
Expand Down Expand Up @@ -1020,7 +1028,15 @@
isempty(systems) ? defs : mapfoldr(namespace_defaults, merge, systems; init = defs)
end

function defaults_and_guesses(sys::AbstractSystem)
merge(guesses(sys), defaults(sys))
end

unknowns(sys::Union{AbstractSystem, Nothing}, v) = renamespace(sys, v)
for vType in [Symbolics.Arr, Symbolics.Symbolic{<:AbstractArray}]
@eval unknowns(sys::AbstractSystem, v::$vType) = renamespace(sys, v)
@eval parameters(sys::AbstractSystem, v::$vType) = toparam(unknowns(sys, v))

Check warning on line 1038 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1037-L1038

Added lines #L1037 - L1038 were not covered by tests
end
parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(unknowns(sys, v))
for f in [:unknowns, :parameters]
@eval function $f(sys::AbstractSystem, vs::AbstractArray)
Expand Down Expand Up @@ -1756,34 +1772,117 @@
op = merge(defs, op)
end
sys = ssys
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
ps = parameters(sys)
initsys = complete(generate_initializesystem(
sys, guesses = guesses(sys), algebraic_only = true))
if p isa SciMLBase.NullParameters
p = Dict()
else
p = todict(p)

Check warning on line 1780 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1780

Added line #L1780 was not covered by tests
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, p, u0)
sys_ps = MTKParameters(sys, p, x0)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)

Check warning on line 1786 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1786

Added line #L1786 was not covered by tests
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
initsys_ps = parameters(initsys)
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
tunable_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
discrete_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
constant_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Constants()]
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
nonnum_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
getu(sys, unknowns(initsys))
get_initprob_u_p = let tunable_getter = tunable_getter,
disc_getter = disc_getter,
const_getter = const_getter,
nonnum_getter = nonnum_getter,
oldps = oldps,
u_getter = u_getter

function (u, p, t)
state = ProblemState(; u, p, t)
if tunable_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Tunable(), oldps, tunable_getter(state))
end
if disc_getter !== nothing
SciMLStructures.replace!(

Check warning on line 1825 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1825

Added line #L1825 was not covered by tests
SciMLStructures.Discrete(), oldps, disc_getter(state))
end
if const_getter !== nothing
SciMLStructures.replace!(

Check warning on line 1829 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1829

Added line #L1829 was not covered by tests
SciMLStructures.Constants(), oldps, const_getter(state))
end
if nonnum_getter !== nothing
SciMLStructures.replace!(

Check warning on line 1833 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1833

Added line #L1833 was not covered by tests
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
end
newu = u_getter(state)
return newu, oldps
end
end
else
p = _p
p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),

Check warning on line 1841 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1841

Added line #L1841 was not covered by tests
u_getter = getu(sys, unknowns(initsys))

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(state), p_getter(state)

Check warning on line 1846 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1844-L1846

Added lines #L1844 - L1846 were not covered by tests
end
end
end
initfn = NonlinearFunction(initsys)
initprobmap = getu(initsys, unknowns(sys))
ps = full_parameters(sys)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, unknowns(sys), ps; p = p),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
initfn = initfn,
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs)
chunk = ForwardDiff.Chunk(input_idxs),
sys_ps = sys_ps,
initialize = initialize,
sys = sys

function (u, p, t)
if !isa(p, MTKParameters)
p = todict(p)
newps = deepcopy(sys_ps)
for (k, v) in p
setp(sys, k)(newps, v)
end

Check warning on line 1873 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1872-L1873

Added lines #L1872 - L1873 were not covered by tests
p = newps
end

if u !== nothing # Handle systems without unknowns
length(sts) == length(u) ||
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
residual = fun(u, p, t)
if norm(residual[alge_idxs]) > √(eps(eltype(residual)))
initu0, initp = get_initprob_u_p(u, p, t)
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
@set! fun.initializeprob = initprob
prob = ODEProblem(fun, u, (t, t + 1), p)
integ = init(prob, OrdinaryDiffEq.Rodas5P())
u = integ.u
Expand Down Expand Up @@ -2051,21 +2150,20 @@
"""
function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives = false,
p = DiffEqBase.NullParameters())
x0 = merge(defaults(sys), op)
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
u0, defs = get_u0(sys, x0, p)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
if p isa SciMLBase.NullParameters
p = op
p = Dict()
elseif p isa Dict
p = merge(p, op)
elseif p isa Vector && eltype(p) <: Pair
p = merge(Dict(p), op)
elseif p isa Vector
p = merge(Dict(parameters(sys) .=> p), op)
end
p2 = MTKParameters(sys, p, Dict(unknowns(sys) .=> u0))
end
linres = lin_fun(u0, p2, t)
linres = lin_fun(u0, p, t)
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres

nx, nu = size(f_u)
Expand Down
12 changes: 9 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1632,10 +1632,16 @@
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))

if isempty(u0map)
u0map = Dict()

Check warning on line 1636 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1636

Added line #L1636 was not covered by tests
end
if isempty(guesses)
guesses = Dict()

Check warning on line 1639 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1639

Added line #L1639 was not covered by tests
end
u0map = merge(todict(guesses), todict(u0map))
if neqs == nunknown
NonlinearProblem(isys, guesses, parammap)
NonlinearProblem(isys, u0map, parammap)
else
NonlinearLeastSquaresProblem(isys, guesses, parammap)
NonlinearLeastSquaresProblem(isys, u0map, parammap)
end
end
6 changes: 4 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ function BufferTemplate(s::Type{<:Symbolics.Struct}, length::Int)
BufferTemplate(T, length)
end

const DEPENDENT_PORTION = :dependent
const NONNUMERIC_PORTION = :nonnumeric
struct Dependent <: SciMLStructures.AbstractPortion end
struct Nonnumeric <: SciMLStructures.AbstractPortion end
const DEPENDENT_PORTION = Dependent()
const NONNUMERIC_PORTION = Nonnumeric()

struct ParameterIndex{P, I}
portion::P
Expand Down
37 changes: 22 additions & 15 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
idxs_diff = isdiffeq.(eqs)
Expand Down Expand Up @@ -68,28 +69,34 @@
defs = merge(defaults(sys), filtered_u0)
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)

for st in full_states
if st ∈ keys(defs)
def = defs[st]
if !algebraic_only
for st in full_states
if st ∈ keys(defs)
def = defs[st]

if def isa Equation
st ∉ keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
if def isa Equation
st ∉ keys(guesses) && check_defguess &&

Check warning on line 78 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L78

Added line #L78 was not covered by tests
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])

Check warning on line 81 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")

Check warning on line 89 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L89

Added line #L89 was not covered by tests
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
nleqs = if algebraic_only
[eqs_ics; observed(sys)]
else
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
end

sys_nl = NonlinearSystem(nleqs,
full_states,
Expand Down
5 changes: 3 additions & 2 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@
end

function _update_tuple_helper(::Type{<:AbstractArray}, buf_v, raw, idx)
ntuple(i -> _update_tuple_helper(buf_v[i], raw, idx), Val(length(buf_v)))
ntuple(i -> _update_tuple_helper(buf_v[i], raw, idx), length(buf_v))

Check warning on line 193 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L193

Added line #L193 was not covered by tests
end

function _update_tuple_helper(::Any, buf_v, raw, idx)
Expand All @@ -210,7 +210,8 @@

for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
(SciMLStructures.Discrete, :discrete)
(SciMLStructures.Constants, :constant)]
(SciMLStructures.Constants, :constant)
(Nonnumeric, :nonnumeric)]
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
as_vector = buffer_to_arraypartition(p.$field)
repack = let as_vector = as_vector, p = p
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/inversemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
p = ModelingToolkit.MTKParameters(simplified_sys, op)
matrices1 = Sf(x, p, 0)
matrices2, _ = Blocks.get_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
@test matrices1.f_x ≈ matrices2.A[1:7, 1:7]
@test_broken matrices1.f_x ≈ matrices2.A[1:7, 1:7]
nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
@test matrices2.A ≈ nsys.A

Expand All @@ -161,6 +161,6 @@ x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
p = ModelingToolkit.MTKParameters(simplified_sys, op)
matrices1 = Sf(x, p, 0)
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
@test matrices1.f_x ≈ matrices2.A[1:7, 1:7]
@test_broken matrices1.f_x ≈ matrices2.A[1:7, 1:7]
nsys = get_named_comp_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
@test matrices2.A ≈ nsys.A
64 changes: 64 additions & 0 deletions test/downstream/linearization_dd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
## Test that dummy_derivatives can be set to zero
# The call to Link(; m = 0.2, l = 10, I = 1, g = -9.807) hangs forever on Julia v1.6
using LinearAlgebra
using ModelingToolkit
using ModelingToolkitStandardLibrary
using ModelingToolkitStandardLibrary.Blocks
using ModelingToolkitStandardLibrary.Mechanical.MultiBody2D
using ModelingToolkitStandardLibrary.Mechanical.TranslationalPosition
using Test

using ControlSystemsMTK
using ControlSystemsMTK.ControlSystemsBase: sminreal, minreal, poles
connect = ModelingToolkit.connect

@parameters t
D = Differential(t)

@named link1 = Link(; m = 0.2, l = 10, I = 1, g = -9.807)
@named cart = TranslationalPosition.Mass(; m = 1, s = 0)
@named fixed = Fixed()
@named force = Force(use_support = false)

eqs = [connect(link1.TX1, cart.flange)
connect(cart.flange, force.flange)
connect(link1.TY1, fixed.flange)]

@named model = ODESystem(eqs, t, [], []; systems = [link1, cart, force, fixed])
def = ModelingToolkit.defaults(model)
def[cart.s] = 10
def[cart.v] = 0
def[link1.A] = -pi / 2
def[link1.dA] = 0
lin_outputs = [cart.s, cart.v, link1.A, link1.dA]
lin_inputs = [force.f.u]

@test_broken begin
@info "named_ss"
G = named_ss(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
allow_input_derivatives = true, zero_dummy_der = true)
G = sminreal(G)
@info "minreal"
G = minreal(G)
@info "poles"
ps = poles(G)

@test minimum(abs, ps) < 1e-6
@test minimum(abs, complex(0, 1.3777260367206716) .- ps) < 1e-10

lsys, syss = linearize(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
allow_input_derivatives = true, zero_dummy_der = true)
lsyss, sysss = ModelingToolkit.linearize_symbolic(model, lin_inputs, lin_outputs;
allow_input_derivatives = true)

dummyder = setdiff(unknowns(sysss), unknowns(model))
def = merge(ModelingToolkit.guesses(model), def, Dict(x => 0.0 for x in dummyder))
def[link1.fy1] = -def[link1.g] * def[link1.m]

@test substitute(lsyss.A, def) ≈ lsys.A
# We cannot pivot symbolically, so the part where a linear solve is required
# is not reliable.
@test substitute(lsyss.B, def)[1:6, 1] ≈ lsys.B[1:6, 1]
@test substitute(lsyss.C, def) ≈ lsys.C
@test substitute(lsyss.D, def) ≈ lsys.D
end
Loading
Loading