Skip to content

Commit 96a5c27

Browse files
AbstractODEFunction
1 parent 073ea74 commit 96a5c27

File tree

3 files changed

+87
-104
lines changed

3 files changed

+87
-104
lines changed

src/maketype.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,23 @@ function maketype(name,param_dict,origex,funcs,syms,fex;
99
params = Symbol[],
1010
pfuncs=Vector{Expr}(0),
1111
d_pfuncs = Vector{Expr}(0),
12-
param_Jex=:())
12+
param_Jex=:(),
13+
f_expr=:(),
14+
tgrad_expr=:(),
15+
jac_expr=:(),
16+
invjac_expr=:(),
17+
invW_expr=:(),
18+
invW_t_expr=:(),
19+
param_jac_expr=:())
1320

14-
typeex = :(mutable struct $name <: DiffEqBase.AbstractParameterizedFunction{true}
21+
typeex = :(mutable struct $name{F,J,T,W,Wt,PJ} <: DiffEqBase.AbstractODEFunction{true}
22+
f::F
23+
analytic::Nothing
24+
jac::J
25+
tgrad::T
26+
invW::W
27+
invW_t::Wt
28+
paramjac::PJ
1529
origex::Expr
1630
funcs::Vector{Expr}
1731
pfuncs::Vector{Expr}
@@ -48,7 +62,9 @@ function maketype(name,param_dict,origex,funcs,syms,fex;
4862
vector_ex_ex = Meta.quot(vector_ex)
4963
vector_ex_return_ex = Meta.quot(vector_ex_return)
5064
param_Jex_ex = Meta.quot(param_Jex)
51-
constructorex = :($(name)(;$(Expr(:kw,:origex,new_ex)),
65+
66+
constructorex = :($(name)(;
67+
$(Expr(:kw,:origex,new_ex)),
5268
$(Expr(:kw,:funcs,funcs)),
5369
$(Expr(:kw,:pfuncs,pfuncs)),
5470
$(Expr(:kw,:d_pfuncs,d_pfuncs)),
@@ -67,11 +83,16 @@ function maketype(name,param_dict,origex,funcs,syms,fex;
6783
$(Expr(:kw,:vector_ex,vector_ex_ex)),
6884
$(Expr(:kw,:vector_ex_return,vector_ex_return_ex)),
6985
$(Expr(:kw,:params,params))) =
70-
$(name)(origex,funcs,pfuncs,d_pfuncs,syms,
86+
$(name)($f_expr,nothing,
87+
$jac_expr,$tgrad_expr,$invW_expr,$invW_t_expr,$param_jac_expr,
88+
origex,funcs,pfuncs,d_pfuncs,syms,
7189
tgradex,Jex,expJex,param_Jex,
7290
invJex,invWex,invWex_t,
7391
Hex,invHex,fex,pex,vector_ex,vector_ex_return,params)) |> esc
7492

93+
callex = :(((f::$name))(args...) = f.f(args...)) |> esc
94+
callex2 = :(((f::$name))(u,p,t::Number) = (du=similar(u);f.f(du,u,p,t);du)) |> esc
95+
7596
# Make the type instance using the default constructor
76-
typeex,constructorex
97+
typeex,constructorex,callex,callex2
7798
end

src/ode_def_opts.jl

Lines changed: 47 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -231,124 +231,82 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
231231
end
232232
end
233233

234-
# Build the type
235-
exprs = Vector{Expr}(undef, 0)
236-
237-
typeex,constructorex = maketype(name,params,origex,funcs,syms,fex,pex=pex,
238-
vector_ex = vector_ex,vector_ex_return = vector_ex_return,
239-
tgradex=tgradex,expJex=expJex,Jex=Jex,
240-
invWex=invWex,invWex_t=invWex_t,
241-
invJex=invJex,Hex=Hex,
242-
invHex=invHex,params=params,
243-
pfuncs=pfuncs,d_pfuncs=d_pfuncs,
244-
param_Jex=param_Jex)
245-
246-
push!(exprs,typeex)
247-
push!(exprs,constructorex)
248-
249-
250-
251-
#=
252-
# Value Dispatches for the Parameters
253-
for i in 1:length(params)
254-
param = Symbol(params[i])
255-
param_func = pfuncs[i]
256-
param_valtype = Val{param}
257-
overloadex = :(((f::$name))(::Type{$param_valtype},t,internal_var___u,$param,internal_var___du) = $param_func) |> esc
258-
push!(exprs,overloadex)
259-
end
260-
=#
261-
262234
# Build the Function
263-
overloadex = :(((f::$name))(internal_var___du,internal_var___u,internal_var___p,t::Number) = $pex) |> esc
264-
push!(exprs,overloadex)
235+
f_expr = :((internal_var___du,internal_var___u,internal_var___p,t::Number) -> $pex)
265236

237+
#=
266238
# Add a method which allocates the `du` and returns it instead of being inplace
267239
overloadex = :(((f::$name))(u,p,t::Number) = (du=similar(u); f(du,u,p,t); du)) |> esc
268240
push!(exprs,overloadex)
269-
270-
#=
271-
# Build the Vectorized functions
272-
overloadex = :(((internal_var___p::$name))(::Type{Val{:vec}},t::Number,internal_var___u,internal_var___du) = $vector_ex) |> esc
273-
push!(exprs,overloadex)
274-
275-
# Build the Vectorized functions
276-
overloadex = :(((internal_var___p::$name))(::Type{Val{:vec}},t::Number,internal_var___u) = $vector_ex_return) |> esc
277-
push!(exprs,overloadex)
278-
279-
overloadex = :(((internal_var___p::$name))(::Type{Val{:vec}},t::Number,u) = (du=similar(u); p(t,internal_var___u,du); du)) |> esc
280-
push!(exprs,overloadex)
281-
=#
282-
283-
# Value Dispatches for the Parameter Derivatives
284-
#=
285-
if pderiv_exists
286-
for i in 1:length(params)
287-
param = Symbol(params[i])
288-
param_func = d_pfuncs[i]
289-
param_valtype = Val{param}
290-
overloadex = :(((internal_var___p::$name))(::Type{Val{:deriv}},::Type{$param_valtype},t,internal_var___u,$param,internal_var___du) = $param_func) |> esc
291-
push!(exprs,overloadex)
292-
end
293-
end
294241
=#
295242

296243
# Add the t gradient
297244
if tgrad_exists
298-
overloadex = :(((f::$name))(::Type{Val{:tgrad}},internal_var___grad,internal_var___u,internal_var___p,t) = $tgradex) |> esc
299-
push!(exprs,overloadex)
245+
tgrad_expr = :((internal_var___grad,internal_var___u,internal_var___p,t) -> $tgradex)
246+
else
247+
tgrad_expr = :(nothing)
300248
end
301249

302250
# Add the Jacobian
303251
if jac_exists
304-
overloadex = :(((f::$name))(::Type{Val{:jac}},internal_var___J,internal_var___u,internal_var___p,t) = $Jex) |> esc
305-
push!(exprs,overloadex)
306-
overloadex = :(((f::$name))(::Type{Val{:jac}},u,p,t::Number) = (J=similar(u, (length(u), length(u))); f(Val{:jac},J,u,p,t); J)) |> esc
307-
push!(exprs,overloadex)
308-
end
309-
310-
#=
311-
# Add the Exponential Jacobian
312-
if expjac_exists
313-
overloadex = :(((internal_var___p::$name))(::Type{Val{:expjac}},t,internal_var___u,internal_γ,internal_var___J) = $expJex) |> esc
314-
push!(exprs,overloadex)
252+
jac_expr = :((internal_var___J,internal_var___u,internal_var___p,t) -> $Jex)
253+
else
254+
jac_expr = :(nothing)
315255
end
316-
=#
317256

318257
# Add the Inverse Jacobian
319258
if invjac_exists
320-
overloadex = :(((f::$name))(::Type{Val{:invjac}},internal_var___J,internal_var___u,internal_var___p,t) = $invJex) |> esc
321-
push!(exprs,overloadex)
259+
invjac_expr = :((internal_var___J,internal_var___u,internal_var___p,t) -> $invJex)
260+
else
261+
invjac_expr = :(nothing)
322262
end
263+
323264
# Add the Inverse Rosenbrock-W
324265
if invW_exists
325-
overloadex = :(((f::$name))(::Type{Val{:invW}},internal_var___J,internal_var___u,internal_var___p,internal_γ,t) = $invWex) |> esc
326-
push!(exprs,overloadex)
266+
invW_expr = :((internal_var___J,internal_var___u,internal_var___p,internal_γ,t) -> $invWex)
267+
else
268+
invW_expr = :(nothing)
327269
end
270+
328271
# Add the Inverse Rosenbrock-W Transformed
329272
if invW_exists
330-
overloadex = :(((f::$name))(::Type{Val{:invW_t}},internal_var___J,internal_var___u,internal_var___p,internal_γ,t) = $invWex_t) |> esc
331-
push!(exprs,overloadex)
332-
end
333-
#=
334-
# Add the Hessian
335-
if hes_exists
336-
overloadex = :(((internal_var___p::$name))(::Type{Val{:hes}},t,internal_var___u,internal_var___J) = $Hex) |> esc
337-
push!(exprs,overloadex)
338-
end
339-
# Add the Inverse Hessian
340-
if invhes_exists
341-
overloadex = :(((internal_var___p::$name))(::Type{Val{:invhes}},t,internal_var___u,internal_var___J) = $invHex) |> esc
342-
push!(exprs,overloadex)
273+
invW_t_expr = :((internal_var___J,internal_var___u,internal_var___p,internal_γ,t) -> $invWex_t)
274+
else
275+
invW_t_expr = :(nothing)
343276
end
344-
=#
345277

346278
# Add Parameter Jacobian
347279
if param_jac_exists
348-
overloadex = :(((f::$name))(::Type{Val{:paramjac}},internal_var___J,internal_var___u,internal_var___p,t) = $param_Jex) |> esc
349-
push!(exprs,overloadex)
280+
param_jac_expr = :((internal_var___J,internal_var___u,internal_var___p,t) -> $param_Jex)
281+
else
282+
param_jac_expr = :(nothing)
350283
end
351284

285+
# Build the type
286+
exprs = Vector{Expr}(undef, 0)
287+
288+
typeex,constructorex,callex,callex2 = maketype(name,params,origex,
289+
funcs,syms,fex,
290+
pex=pex,
291+
vector_ex = vector_ex,vector_ex_return = vector_ex_return,
292+
tgradex=tgradex,expJex=expJex,Jex=Jex,
293+
invWex=invWex,invWex_t=invWex_t,
294+
invJex=invJex,Hex=Hex,
295+
invHex=invHex,params=params,
296+
pfuncs=pfuncs,d_pfuncs=d_pfuncs,
297+
param_Jex=param_Jex,
298+
f_expr=f_expr,tgrad_expr=tgrad_expr,
299+
jac_expr=jac_expr,
300+
invjac_expr=invjac_expr,
301+
invW_expr=invW_expr,
302+
invW_t_expr=invW_t_expr,
303+
param_jac_expr=param_jac_expr)
304+
305+
push!(exprs,typeex)
306+
push!(exprs,constructorex)
307+
push!(exprs,callex)
308+
push!(exprs,callex2)
309+
352310
# Return the type from the default consturctor
353311
def_const_ex = :(($name)()) |> esc
354312
push!(exprs,def_const_ex)

test/runtests.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,36 +51,34 @@ f_t2(du,u,p,t)
5151
@test du == [1.0,-3.0]
5252

5353
println("Test t-gradient")
54-
f(Val{:tgrad},grad,u,p,t)
54+
f.tgrad(grad,u,p,t)
5555
@test grad == zeros(2)
56-
f_t(Val{:tgrad},grad,u,p,t)
56+
f_t.tgrad(grad,u,p,t)
5757
@test grad == [0.0;12.0]
5858

5959
println("Test Jacobians")
60-
f(Val{:jac},J,u,p,t)
60+
f.jac(J,u,p,t)
6161
@test J == [-1.5 -2.0
6262
3.0 -1.0]
63-
@test f(Val{:jac}, u, p, t) == [-1.5 -2.0; 3.0 -1.0]
63+
#@test f.jac(u, p, t) == [-1.5 -2.0; 3.0 -1.0]
6464

6565
println("Test Inv Rosenbrock-W")
66-
f(Val{:invW},iW,u,p,2.0,t)
66+
f.invW(iW,u,p,2.0,t)
6767
@test minimum(iW - inv(I - 2*J) .< 1e-10)
6868

69-
f(Val{:invW_t},iW,u,p,2.0,t)
69+
f.invW_t(iW,u,p,2.0,t)
7070
@test minimum(iW - inv(I/2 - J) .< 1e-10)
7171

7272
println("Parameter Jacobians")
7373
pJ = Matrix{Float64}(2,4)
74-
f(Val{:paramjac},pJ,u,[2.0;2.5;3.0;1.0],t)
74+
f.paramjac(pJ,u,[2.0;2.5;3.0;1.0],t)
7575
@test pJ == [2.0 -6.0 0 0.0
7676
0 0 -3.0 6.0]
7777

7878
@code_llvm DiffEqBase.has_jac(f)
7979

8080
println("Test booleans")
8181
@test DiffEqBase.has_jac(f) == true
82-
@test DiffEqBase.has_hes(f) == false
83-
@test DiffEqBase.has_invhes(f) == false
8482
@test DiffEqBase.has_paramjac(f) == true
8583

8684
@code_llvm DiffEqBase.has_paramjac(f)
@@ -107,5 +105,11 @@ NJ(du,u,[1.5,1,3,4],t)
107105
@test_throws MethodError NJ(Val{:jac},iJ,u,p,t)
108106
# NJ(Val{:jac},t,u,J) # Currently gives E not defined, will be fixed by the next SymEgine
109107

110-
println("Make sure all of the problems in the problem library build")
108+
println("Make the problems in the problem library build")
109+
111110
using DiffEqProblemLibrary
111+
using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems
112+
using DiffEqProblemLibrary.SDEProblemLibrary: importsdeproblems
113+
114+
importodeproblems()
115+
importsdeproblems()

0 commit comments

Comments
 (0)