diff --git a/src/check_struct_fields.jl b/src/check_struct_fields.jl index 9bb9464..d90fd1a 100644 --- a/src/check_struct_fields.jl +++ b/src/check_struct_fields.jl @@ -46,24 +46,23 @@ function check_field_type_fully_specified( mod::Module, struct_name, field_name, typevars, field_type_expr; report_error, location, ) - print("TypeVars: ") - dump(typevars) - @show mod + @debug "Module:" mod + @debug "TypeVars:" typevars TypeObj = Base.eval(mod, quote $(field_type_expr) where {$(typevars...)} end) - @show TypeObj + @debug "Type:" TypeObj @assert TypeObj isa Type if isconcretetype(TypeObj) # The type is concrete, so it is fully specified. - @info "Type is concrete: $(TypeObj)" + @debug "Type is concrete: $(TypeObj)" return true end if typeof(TypeObj) == DataType # The type is a DataType, so it is fully specified. # Presumably, it is an abstract type like `Number` or `Any` - @info "Type is a DataType: $(TypeObj)" + @debug "Type is a DataType: $(TypeObj)" return true end if typeof(TypeObj) == Union @@ -77,10 +76,10 @@ function check_field_type_fully_specified( end @assert typeof(TypeObj) === UnionAll "$(TypeObj) is not a UnionAll. Got $(typeof(TypeObj))." - num_type_params = _count_unionall_parameters(TypeObj) - num_expr_args = _count_type_expr_params(field_type_expr) - # "Less than or equal to" in order to support literal values in the type expression. - # E.g.: The UnionAll `Array{<:Int, 1}` has 1 type arg but 2 params in the expression. + num_type_params = _count_unionall_free_parameters(TypeObj) + num_expr_args = _count_type_expr_params(mod, field_type_expr) + # "Less than or equal to" in order to support fully constrained parameters in the expr. + # E.g.: `Vector{T} where T<:Int` has 0 free type params but 1 param in the expression. success = num_type_params <= num_expr_args if report_error @assert success field_type_not_complete_message( @@ -101,7 +100,7 @@ function field_type_not_complete_message( typename = nameof(TypeObj) typevars = join(["T$i" for i in 1:(num_type_params - num_expr_args)], ", ") typestr = "$(field_type_expr)" - @show typestr + @debug "Type string:" typestr if occursin("}", typestr) typestr = replace(typestr, "}" => ", $(typevars)}") else @@ -132,23 +131,30 @@ function field_type_not_complete_message( """ end -function _count_unionall_parameters(TypeObj::UnionAll) +_count_unionall_free_parameters(@nospecialize(::Any)) = 0 +function _count_unionall_free_parameters(TypeObj::UnionAll) count = 0 while typeof(TypeObj) === UnionAll - count += 1 + # in `T<:ConcreteType` we don't consider `T` as a free parameter + count += !isconcretetype(TypeObj.var.ub) TypeObj = TypeObj.body end + # The parameters might themselves be `UnionAll` + for param in TypeObj.parameters + count += _count_unionall_free_parameters(param) + end return count end -_count_type_expr_params(s::Symbol) = 0 -function _count_type_expr_params(expr::Expr) + +_count_type_expr_params(mod::Module, @nospecialize(x)) = 0 +_count_type_expr_params(mod::Module, s::Symbol) = Int(!isdefined(mod, s)) +function _count_type_expr_params(mod::Module, expr::Expr) count = 0 - while expr.head === :where - count += length(expr.args) - 1 + if expr.head === :where expr = expr.args[1] - end - if expr.head == :curly - count += length(expr.args) - 1 + count = _count_type_expr_params(mod, expr) + elseif expr.head == :curly + count += sum(_count_type_expr_params(mod, arg) for arg in expr.args) end return count end diff --git a/test/check_struct_fields_tests.jl b/test/check_struct_fields_tests.jl index ed03c49..c0a703f 100644 --- a/test/check_struct_fields_tests.jl +++ b/test/check_struct_fields_tests.jl @@ -1,4 +1,3 @@ - @testitem "check field tests" begin # Concrete DataType @@ -17,6 +16,17 @@ :(struct S x::Vector{Int} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S x::Vector{Pair{Dict{Int,String},Set{Bool}}} end), + :x, + ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Pair{Dict{Bool}, String} end), + :x, + ) + # Fully specified field type (DataType) @test field_is_fully_specified( @__MODULE__, @@ -28,6 +38,11 @@ :(mutable struct S{T<:Int} x::Vector{T} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T<:Int} x::Dict{T,Int} end), + :x, + ) # Fully specified UnionAll field type @test field_is_fully_specified( @@ -67,6 +82,11 @@ :(struct S{T1,T2} x::Dict{T1,T2} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1,T2} x::Vector{Dict{T1,T2}} end), + :x, + ) @test field_is_fully_specified( @__MODULE__, :(struct S{T0,T1,T2} x::Dict{T1,T2} end), @@ -77,22 +97,48 @@ :(struct S{T1} x::Dict{T1,Int} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Dict{T1,T1} end), + :x, + ) @test field_is_fully_specified( @__MODULE__, :(struct S{T2} x::Dict{Int,T2} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T2} x::Pair{Dict{Int,T2},T2} end), + :x, + ) + # where-clause on the field @test field_is_fully_specified( @__MODULE__, :(struct S{T1} x::Dict{T1, T2} where {T2<:Int} end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Dict{T1, T2} where {T2<:Integer} end), + :x, + ) @test field_is_fully_specified( @__MODULE__, :(struct S{T1} x::Dict{T1,T2} where T2 end), :x, ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Pair{Dict{K, Int64}, String} where K end), + :x, + ) + @test field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Pair{Dict{K, Int64} where K, String} end), + :x, + ) # False cases: @test false == field_is_fully_specified( @@ -110,6 +156,11 @@ :(struct S{T} x::Dict end), :x, ) + @test false == field_is_fully_specified( + @__MODULE__, + :(struct S{T1} x::Pair{Dict{Bool, Int64}} end), + :x, + ) # Not applicable (abstract type or no type at all): @test field_is_fully_specified( diff --git a/test/runtests.jl b/test/runtests.jl index af35024..be969d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using FullySpecifiedFieldTypesStaticTests +using StructFieldParamsTesting using ReTestItems -ReTestItems.runtests(FullySpecifiedFieldTypesStaticTests) +runtests(StructFieldParamsTesting; nworkers=1)