diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index bf0ef99a..7fe4bedc 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt using ForwardDiff, StaticArrays using ForwardDiff.LinearAlgebra using ForwardDiff.DiffResults -using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, +using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, maketagtype, gradient, hessian, jacobian, gradient!, hessian!, jacobian!, extract_gradient!, extract_jacobian!, extract_value!, vector_mode_gradient, vector_mode_gradient!, @@ -51,12 +51,12 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi end @inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) return extract_gradient(T, f(dualize(T, x)), x) end @inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) return extract_gradient!(T, result, f(dualize(T, x))) end @@ -81,7 +81,7 @@ end end @inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) return extract_jacobian(T, f(dualize(T, x)), x) end @@ -91,7 +91,7 @@ function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where end @inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) ydual = f(dualize(T, x)) result = extract_jacobian!(T, result, ydual, length(x)) result = extract_value!(T, result, ydual) @@ -99,7 +99,7 @@ end end @inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) ydual = f(dualize(T, x)) result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x)) result = DiffResults.value!(Base.Fix1(value, T), result, ydual) @@ -119,7 +119,7 @@ ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::Hes ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian!(result, f, x) function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F} - T = typeof(Tag(f, eltype(x))) + T = maketagtype(f,eltype(x)) d1 = dualize(T, x) d2 = dualize(T, d1) fd2 = f(d2) diff --git a/src/config.jl b/src/config.jl index 3c6c97e3..1fd2cbac 100644 --- a/src/config.jl +++ b/src/config.jl @@ -1,9 +1,7 @@ ####### # Tag # ####### - -struct Tag{F,V} -end +struct Tag{F,V} <: AbstractTag{F,V} end const TAGCOUNT = Threads.Atomic{UInt}(0) @@ -20,11 +18,13 @@ end Tag(::Nothing, ::Type{V}) where {V} = nothing - @inline function ≺(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2} tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2}) end +#default implementation of maketag. +maketag(f::F,::Type{V}) where {F,V} = Tag(f,V) + struct InvalidTagException{E,O} <: Exception end @@ -38,8 +38,16 @@ checktag(::Type{Tag{F,V}}, f::F, x::AbstractArray{V}) where {F,V} = true # no easy way to check Jacobian tag used with Hessians as multiple functions may be used checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V} = true +checktag(::Type{AbstractTag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V} = true + +#AbstractTag support +function checktag(T::Type{AbstractTag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT,VT,F,V} + T2 = maketagtype(f,V) #maketag(f::F,type{V})::AbstractTag{F2,V2} is not equivalent to Tag{F,V} + T2 !== T && throw(InvalidTagException{T,T2}()) + return true +end -# custom tag: you're on your own. +# custom tag not in the tag API: you're on your own. checktag(z, f, x) = true @@ -82,7 +90,7 @@ This constructor does not store/modify `y` or `x`. function DerivativeConfig(f::F, y::AbstractArray{Y}, x::X, - tag::T = Tag(f, X)) where {F,X<:Real,Y<:Real,T} + tag::T = maketag(f, X)) where {F,X<:Real,Y<:Real,T} duals = similar(y, Dual{T,Y,1}) return DerivativeConfig{T,typeof(duals)}(duals) end @@ -117,7 +125,7 @@ This constructor does not store/modify `x`. function GradientConfig(f::F, x::AbstractArray{V}, ::Chunk{N} = Chunk(x), - ::T = Tag(f, V)) where {F,V,N,T} + ::T = maketag(f, V)) where {F,V,N,T} seeds = construct_seeds(Partials{N,V}) duals = similar(x, Dual{T,V,N}) return GradientConfig{T,V,N,typeof(duals)}(seeds, duals) @@ -154,7 +162,7 @@ This constructor does not store/modify `x`. function JacobianConfig(f::F, x::AbstractArray{V}, ::Chunk{N} = Chunk(x), - ::T = Tag(f, V)) where {F,V,N,T} + ::T = maketag(f, V)) where {F,V,N,T} seeds = construct_seeds(Partials{N,V}) duals = similar(x, Dual{T,V,N}) return JacobianConfig{T,V,N,typeof(duals)}(seeds, duals) @@ -180,7 +188,7 @@ function JacobianConfig(f::F, y::AbstractArray{Y}, x::AbstractArray{X}, ::Chunk{N} = Chunk(x), - ::T = Tag(f, X)) where {F,Y,X,N,T} + ::T = maketag(f, X)) where {F,Y,X,N,T} seeds = construct_seeds(Partials{N,X}) yduals = similar(y, Dual{T,Y,N}) xduals = similar(x, Dual{T,X,N}) @@ -221,7 +229,7 @@ This constructor does not store/modify `x`. function HessianConfig(f::F, x::AbstractArray{V}, chunk::Chunk = Chunk(x), - tag = Tag(f, V)) where {F,V} + tag = maketag(f, V)) where {F,V} jacobian_config = JacobianConfig(f, x, chunk, tag) gradient_config = GradientConfig(f, jacobian_config.duals, chunk, tag) return HessianConfig(jacobian_config, gradient_config) @@ -246,7 +254,7 @@ function HessianConfig(f::F, result::DiffResult, x::AbstractArray{V}, chunk::Chunk = Chunk(x), - tag = Tag(f, V)) where {F,V} + tag = maketag(f, V)) where {F,V} jacobian_config = JacobianConfig((f,gradient), DiffResults.gradient(result), x, chunk, tag) gradient_config = GradientConfig(f, jacobian_config.duals[2], chunk, tag) return HessianConfig(jacobian_config, gradient_config) diff --git a/src/derivative.jl b/src/derivative.jl index b39e2a48..2ec0b8fc 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -10,7 +10,7 @@ Return `df/dx` evaluated at `x`, assuming `f` is called as `f(x)`. This method assumes that `isa(f(x), Union{Real,AbstractArray})`. """ @inline function derivative(f::F, x::R) where {F,R<:Real} - T = typeof(Tag(f, R)) + T = maketagtype(f,R) return extract_derivative(T, f(Dual{T}(x, one(x)))) end @@ -44,7 +44,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`. @inline function derivative!(result::Union{AbstractArray,DiffResult}, f::F, x::R) where {F,R<:Real} result isa DiffResult || require_one_based_indexing(result) - T = typeof(Tag(f, R)) + T = maketagtype(f,R) ydual = f(Dual{T}(x, one(x))) result = extract_value!(T, result, ydual) result = extract_derivative!(T, result, ydual) diff --git a/src/prelude.jl b/src/prelude.jl index 9e037afa..48aed5ee 100644 --- a/src/prelude.jl +++ b/src/prelude.jl @@ -67,3 +67,22 @@ function qualified_cse!(expr) end return cse_expr end + +#= +###### +# AbstractTag interface + +Required definitions: +- ≺ (between two AbstractTags of the same type) +- maketag(f,::Type{V}) where {V <: Real} + +Optional definitions: +- ≺ (between two AbstractTags of the of different type) +- maketagtype(f,::Type{V}) where {V <: Real} (default: defined in terms of maketag) +- checktag(tag::MyTagType,f,x) +###### +=# +abstract type AbstractTag{F,V} end + +maketag(::Nothing,::Type{V}) where {V} = nothing +@inline maketagtype(f::F,::Type{V}) where {F,V} = typeof(maketag(f,V))