Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions ext/ForwardDiffStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt
using ForwardDiff, StaticArrays
using ForwardDiff.LinearAlgebra
using ForwardDiff.DiffResults
using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, maketagtype,
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
extract_gradient!, extract_jacobian!, extract_value!,
vector_mode_gradient, vector_mode_gradient!,
Expand Down Expand Up @@ -51,12 +51,12 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
end

@inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
return extract_gradient(T, f(dualize(T, x)), x)
end

@inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
return extract_gradient!(T, result, f(dualize(T, x)))
end

Expand All @@ -81,7 +81,7 @@ end
end

@inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
return extract_jacobian(T, f(dualize(T, x)), x)
end

Expand All @@ -91,15 +91,15 @@ function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where
end

@inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
ydual = f(dualize(T, x))
result = extract_jacobian!(T, result, ydual, length(x))
result = extract_value!(T, result, ydual)
return result
end

@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
ydual = f(dualize(T, x))
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
result = DiffResults.value!(Base.Fix1(value, T), result, ydual)
Expand All @@ -119,7 +119,7 @@ ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::Hes
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian!(result, f, x)

function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
T = maketagtype(f,eltype(x))
d1 = dualize(T, x)
d2 = dualize(T, d1)
fd2 = f(d2)
Expand Down
30 changes: 19 additions & 11 deletions src/config.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#######
# Tag #
#######

struct Tag{F,V}
end
struct Tag{F,V} <: AbstractTag{F,V} end

const TAGCOUNT = Threads.Atomic{UInt}(0)

Expand All @@ -20,11 +18,13 @@ end

Tag(::Nothing, ::Type{V}) where {V} = nothing


@inline function ≺(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2}
tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2})
end

#default implementation of maketag.
maketag(f::F,::Type{V}) where {F,V} = Tag(f,V)

struct InvalidTagException{E,O} <: Exception
end

Expand All @@ -38,8 +38,16 @@ checktag(::Type{Tag{F,V}}, f::F, x::AbstractArray{V}) where {F,V} = true

# no easy way to check Jacobian tag used with Hessians as multiple functions may be used
checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V} = true
checktag(::Type{AbstractTag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V} = true

#AbstractTag support
function checktag(T::Type{AbstractTag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT,VT,F,V}
T2 = maketagtype(f,V) #maketag(f::F,type{V})::AbstractTag{F2,V2} is not equivalent to Tag{F,V}
T2 !== T && throw(InvalidTagException{T,T2}())
return true
end

# custom tag: you're on your own.
# custom tag not in the tag API: you're on your own.
checktag(z, f, x) = true


Expand Down Expand Up @@ -82,7 +90,7 @@ This constructor does not store/modify `y` or `x`.
function DerivativeConfig(f::F,
y::AbstractArray{Y},
x::X,
tag::T = Tag(f, X)) where {F,X<:Real,Y<:Real,T}
tag::T = maketag(f, X)) where {F,X<:Real,Y<:Real,T}
duals = similar(y, Dual{T,Y,1})
return DerivativeConfig{T,typeof(duals)}(duals)
end
Expand Down Expand Up @@ -117,7 +125,7 @@ This constructor does not store/modify `x`.
function GradientConfig(f::F,
x::AbstractArray{V},
::Chunk{N} = Chunk(x),
::T = Tag(f, V)) where {F,V,N,T}
::T = maketag(f, V)) where {F,V,N,T}
seeds = construct_seeds(Partials{N,V})
duals = similar(x, Dual{T,V,N})
return GradientConfig{T,V,N,typeof(duals)}(seeds, duals)
Expand Down Expand Up @@ -154,7 +162,7 @@ This constructor does not store/modify `x`.
function JacobianConfig(f::F,
x::AbstractArray{V},
::Chunk{N} = Chunk(x),
::T = Tag(f, V)) where {F,V,N,T}
::T = maketag(f, V)) where {F,V,N,T}
seeds = construct_seeds(Partials{N,V})
duals = similar(x, Dual{T,V,N})
return JacobianConfig{T,V,N,typeof(duals)}(seeds, duals)
Expand All @@ -180,7 +188,7 @@ function JacobianConfig(f::F,
y::AbstractArray{Y},
x::AbstractArray{X},
::Chunk{N} = Chunk(x),
::T = Tag(f, X)) where {F,Y,X,N,T}
::T = maketag(f, X)) where {F,Y,X,N,T}
seeds = construct_seeds(Partials{N,X})
yduals = similar(y, Dual{T,Y,N})
xduals = similar(x, Dual{T,X,N})
Expand Down Expand Up @@ -221,7 +229,7 @@ This constructor does not store/modify `x`.
function HessianConfig(f::F,
x::AbstractArray{V},
chunk::Chunk = Chunk(x),
tag = Tag(f, V)) where {F,V}
tag = maketag(f, V)) where {F,V}
jacobian_config = JacobianConfig(f, x, chunk, tag)
gradient_config = GradientConfig(f, jacobian_config.duals, chunk, tag)
return HessianConfig(jacobian_config, gradient_config)
Expand All @@ -246,7 +254,7 @@ function HessianConfig(f::F,
result::DiffResult,
x::AbstractArray{V},
chunk::Chunk = Chunk(x),
tag = Tag(f, V)) where {F,V}
tag = maketag(f, V)) where {F,V}
jacobian_config = JacobianConfig((f,gradient), DiffResults.gradient(result), x, chunk, tag)
gradient_config = GradientConfig(f, jacobian_config.duals[2], chunk, tag)
return HessianConfig(jacobian_config, gradient_config)
Expand Down
4 changes: 2 additions & 2 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Return `df/dx` evaluated at `x`, assuming `f` is called as `f(x)`.
This method assumes that `isa(f(x), Union{Real,AbstractArray})`.
"""
@inline function derivative(f::F, x::R) where {F,R<:Real}
T = typeof(Tag(f, R))
T = maketagtype(f,R)
return extract_derivative(T, f(Dual{T}(x, one(x))))
end

Expand Down Expand Up @@ -44,7 +44,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`.
@inline function derivative!(result::Union{AbstractArray,DiffResult},
f::F, x::R) where {F,R<:Real}
result isa DiffResult || require_one_based_indexing(result)
T = typeof(Tag(f, R))
T = maketagtype(f,R)
ydual = f(Dual{T}(x, one(x)))
result = extract_value!(T, result, ydual)
result = extract_derivative!(T, result, ydual)
Expand Down
19 changes: 19 additions & 0 deletions src/prelude.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,22 @@ function qualified_cse!(expr)
end
return cse_expr
end

#=
######
# AbstractTag interface

Required definitions:
- ≺ (between two AbstractTags of the same type)
- maketag(f,::Type{V}) where {V <: Real}

Optional definitions:
- ≺ (between two AbstractTags of the of different type)
- maketagtype(f,::Type{V}) where {V <: Real} (default: defined in terms of maketag)
- checktag(tag::MyTagType,f,x)
######
=#
abstract type AbstractTag{F,V} end

maketag(::Nothing,::Type{V}) where {V} = nothing
@inline maketagtype(f::F,::Type{V}) where {F,V} = typeof(maketag(f,V))
Loading