Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: towards transitive specificity #30171

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,18 @@ getindex(::Type{T}, x) where {T} = (@_inline_meta; a = Vector{T}(undef, 1); @inb
getindex(::Type{T}, x, y) where {T} = (@_inline_meta; a = Vector{T}(undef, 2); @inbounds (a[1] = x; a[2] = y); a)
getindex(::Type{T}, x, y, z) where {T} = (@_inline_meta; a = Vector{T}(undef, 3); @inbounds (a[1] = x; a[2] = y; a[3] = z); a)

getindex(::Type{Any}) = Vector{Any}()
getindex(::Type{Any}, @nospecialize(x)) = (a = Vector{Any}(undef, 1); @inbounds a[1] = x; a)
getindex(::Type{Any}, @nospecialize(x), @nospecialize(y)) = (a = Vector{Any}(undef, 2); @inbounds (a[1] = x; a[2] = y); a)
getindex(::Type{Any}, @nospecialize(x), @nospecialize(y), @nospecialize(z)) = (a = Vector{Any}(undef, 3); @inbounds (a[1] = x; a[2] = y; a[3] = z); a)

function getindex(::Type{Any}, @nospecialize vals...)
a = Vector{Any}(undef, length(vals))
@inbounds for i = 1:length(vals)
a[i] = vals[i]
end
return a
end
getindex(::Type{Any}) = Vector{Any}()

function fill!(a::Union{Array{UInt8}, Array{Int8}}, x::Integer)
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), a, x, length(a))
Expand Down
4 changes: 3 additions & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ function showerror(io::IO, ex::TypeError)
if ex.expected === Bool
print(io, "non-boolean (", typeof(ex.got), ") used in boolean context")
else
if isa(ex.got, Type)
if isvarargtype(ex.got)
targs = (ex.got,)
elseif isa(ex.got, Type)
targs = ("Type{", ex.got, "}")
else
targs = (typeof(ex.got),)
Expand Down
6 changes: 4 additions & 2 deletions base/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ getindex(S::Slice, i::Int) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
getindex(S::Slice, i::AbstractUnitRange{<:Integer}) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
getindex(S::Slice, i::StepRange{<:Integer}) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
show(io::IO, r::Slice) = print(io, "Base.Slice(", r.indices, ")")
iterate(S::Slice, s...) = iterate(S.indices, s...)
iterate(S::Slice) = iterate(S.indices)
iterate(S::Slice, s) = iterate(S.indices, s)

"""
LinearIndices(A::AbstractArray)
Expand Down Expand Up @@ -419,7 +420,8 @@ function getindex(iter::LinearIndices, i::AbstractRange{<:Integer})
end
# More efficient iteration — predominantly for non-vector LinearIndices
# but one-dimensional LinearIndices must be special-cased to support OffsetArrays
iterate(iter::LinearIndices{1}, s...) = iterate(axes1(iter.indices[1]), s...)
iterate(iter::LinearIndices{1}) = iterate(axes1(iter.indices[1]))
iterate(iter::LinearIndices{1}, s) = iterate(axes1(iter.indices[1]), s)
iterate(iter::LinearIndices, i=1) = i > length(iter) ? nothing : (i, i+1)

# Needed since firstindex and lastindex are defined in terms of LinearIndices
Expand Down
1 change: 1 addition & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ function promote(x, y, z, a...)
p
end

promote(x::T, y::T) where {T} = (x, y)
promote(x::T, y::T, zs::T...) where {T} = (x, y, zs...)

not_sametype(x::T, y::T) where {T} = sametype_error(x)
Expand Down
2 changes: 1 addition & 1 deletion contrib/generate_precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function generate_precompile_statements()
# println(statement)
# Work around #28808
occursin("\"YYYY-mm-dd\\THH:MM:SS\"", statement) && continue
statement == "precompile(Tuple{typeof(Base.show), Base.IOContext{Base.TTY}, Type{Vararg{Any, N} where N}})" && continue
occursin("Type{Vararg", statement) && continue
try
Base.include_string(PrecompileStagingArea, statement)
catch
Expand Down
136 changes: 43 additions & 93 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -2552,8 +2552,13 @@ static int tuple_morespecific(jl_datatype_t *cdt, jl_datatype_t *pdt, int invari
jl_vararg_kind_t ckind = jl_vararg_kind(clast);
int cva = ckind > JL_VARARG_INT;
int pva = jl_vararg_kind(jl_tparam(pdt,plen-1)) > JL_VARARG_INT;
if (!cva && !pva && clen > plen)
return 0;
if (cva && !pva && clen > plen+1)
return 0;
int cdiag = 0, pdiag = 0;
int some_morespecific = 0;
int va_morespecific = 0; // c vararg type is more specific than all p types
while (1) {
if (cva && pva && i >= clen && i >= plen)
break;
Expand Down Expand Up @@ -2588,78 +2593,27 @@ static int tuple_morespecific(jl_datatype_t *cdt, jl_datatype_t *pdt, int invari
int cms = type_morespecific_(ce, pe, invariant, env);
int eqv = !cms && eq_msp(ce, pe, env);

if (!cms && !eqv && !sub_msp(ce, pe, env)) {
/*
A bound vararg tuple can be more specific despite disjoint elements in order to
preserve transitivity. For example in
A = Tuple{Array{T,N}, Vararg{Int,N}} where {T,N}
B = Tuple{Array, Int}
C = Tuple{AbstractArray, Int, Array}
we need A < B < C and A < C.
*/
return some_morespecific && cva && ckind == JL_VARARG_BOUND && num_occurs((jl_tvar_t*)jl_tparam1(jl_unwrap_unionall(clast)), env) > 1;
}
if (!cms && !eqv && !sub_msp(ce, pe, env))
return 0;

// Tuple{..., T} not more specific than Tuple{..., Vararg{S}} if S is diagonal
if (eqv && i == clen-1 && clen == plen && !cva && pva && jl_is_typevar(ce) && jl_is_typevar(pe) && !cdiag && pdiag)
return 0;

if (cms && cva && i == clen-1)
va_morespecific = 1;

if (cms) some_morespecific = 1;
i++;
}
if (cva && pva && clen > plen && (!pdiag || cdiag))
return 1;
if (cva && !pva && !some_morespecific)
if (cva && !pva && !va_morespecific)
// ambiguity: c is more specific in type but p is more specific in count
return 0;
return some_morespecific || (cdiag && !pdiag);
}

static size_t tuple_full_length(jl_value_t *t)
{
size_t n = jl_nparams(t);
if (n == 0) return 0;
jl_value_t *last = jl_unwrap_unionall(jl_tparam(t,n-1));
if (jl_is_vararg_type(last)) {
jl_value_t *N = jl_tparam1(last);
if (jl_is_long(N))
n += jl_unbox_long(N)-1;
}
return n;
}

// Called when a is a bound-vararg and b is not a vararg. Sets the vararg length
// in a to match b, as long as this makes some earlier argument more specific.
static int args_morespecific_fix1(jl_value_t *a, jl_value_t *b, int swap, jl_typeenv_t *env)
{
size_t n = jl_nparams(a);
int taillen = tuple_full_length(b)-n+1;
if (taillen <= 0)
return -1;
assert(jl_is_va_tuple((jl_datatype_t*)a));
jl_datatype_t *new_a = NULL;
jl_value_t *e[2] = { jl_tparam1(jl_unwrap_unionall(jl_tparam(a, n-1))), jl_box_long(taillen) };
JL_GC_PUSH2(&new_a, &e[1]);
new_a = (jl_datatype_t*)jl_instantiate_type_with((jl_value_t*)a, e, 1);
int changed = 0;
for (size_t i = 0; i < n-1; i++) {
if (jl_tparam(a, i) != jl_tparam(new_a, i)) {
changed = 1;
break;
}
}
int ret = -1;
if (changed) {
if (eq_msp(b, (jl_value_t*)new_a, env))
ret = swap;
else if (swap)
ret = type_morespecific_(b, (jl_value_t*)new_a, 0, env);
else
ret = type_morespecific_((jl_value_t*)new_a, b, 0, env);
}
JL_GC_POP();
return ret;
}

static int count_occurs(jl_value_t *t, jl_tvar_t *v)
{
if (t == (jl_value_t*)v)
Expand Down Expand Up @@ -2698,33 +2652,7 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
if (a == b)
return 0;

if (jl_is_unionall(a)) {
jl_unionall_t *ua = (jl_unionall_t*)a;
jl_typeenv_t newenv = { ua->var, 0x0, env };
newenv.val = (jl_value_t*)(intptr_t)count_occurs(ua->body, ua->var);
return type_morespecific_(ua->body, b, invariant, &newenv);
}
if (jl_is_unionall(b)) {
jl_unionall_t *ub = (jl_unionall_t*)b;
jl_typeenv_t newenv = { ub->var, 0x0, env };
newenv.val = (jl_value_t*)(intptr_t)count_occurs(ub->body, ub->var);
return type_morespecific_(a, ub->body, invariant, &newenv);
}

if (jl_is_tuple_type(a) && jl_is_tuple_type(b)) {
// When one is JL_VARARG_BOUND and the other has fixed length,
// allow the argument length to fix the tvar
jl_vararg_kind_t akind = jl_va_tuple_kind((jl_datatype_t*)a);
jl_vararg_kind_t bkind = jl_va_tuple_kind((jl_datatype_t*)b);
int ans = -1;
if (akind == JL_VARARG_BOUND && bkind < JL_VARARG_BOUND) {
ans = args_morespecific_fix1(a, b, 0, env);
if (ans == 1) return 1;
}
if (bkind == JL_VARARG_BOUND && akind < JL_VARARG_BOUND) {
ans = args_morespecific_fix1(b, a, 1, env);
if (ans == 0) return 0;
}
return tuple_morespecific((jl_datatype_t*)a, (jl_datatype_t*)b, invariant, env);
}

Expand Down Expand Up @@ -2797,13 +2725,20 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
for(size_t i=0; i < jl_nparams(tta); i++) {
jl_value_t *apara = jl_tparam(tta,i);
jl_value_t *bpara = jl_tparam(ttb,i);
if (!jl_has_free_typevars(apara) && !jl_has_free_typevars(bpara) &&
!jl_types_equal(apara, bpara))
int afree = jl_has_free_typevars(apara);
int bfree = jl_has_free_typevars(bpara);
if (!afree && !bfree && !jl_types_equal(apara, bpara))
return 0;
if (type_morespecific_(apara, bpara, 1, env))
if (type_morespecific_(apara, bpara, 1, env) && (jl_is_typevar(apara) || !afree || bfree))
ascore += 1;
else if (type_morespecific_(bpara, apara, 1, env))
else if (type_morespecific_(bpara, apara, 1, env) && (jl_is_typevar(bpara) || !bfree || afree))
bscore += 1;
else if (eq_msp(apara, bpara, env)) {
if (!afree && bfree)
ascore += 1;
else if (afree && !bfree)
bscore += 1;
}
if (jl_is_typevar(bpara) && !jl_is_typevar(apara) && !jl_is_type(apara))
ascore1 = 1;
else if (jl_is_typevar(apara) && !jl_is_typevar(bpara) && !jl_is_type(bpara))
Expand All @@ -2829,9 +2764,6 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
return 0;
return ascore > bscore || adiag > bdiag;
}
else if (invariant) {
return 0;
}
tta = tta->super; super = 1;
}
return 0;
Expand All @@ -2854,8 +2786,9 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
if (((jl_tvar_t*)a)->ub == jl_bottom_type)
return 1;
if (jl_has_free_typevars(b)) {
if (type_morespecific_(((jl_tvar_t*)a)->ub, b, 0, env) ||
eq_msp(((jl_tvar_t*)a)->ub, b, env))
if (type_morespecific_(((jl_tvar_t*)a)->ub, b, 0, env))
return 1;
if (eq_msp(((jl_tvar_t*)a)->ub, b, env))
return num_occurs((jl_tvar_t*)a, env) >= 2;
}
else {
Expand All @@ -2876,12 +2809,29 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
return num_occurs((jl_tvar_t*)b, env) < 2;
}
else {
if (obviously_disjoint(a, ((jl_tvar_t*)b)->ub, 1))
return 0;
if (type_morespecific_(((jl_tvar_t*)b)->ub, a, 0, env))
return 0;
return 1;
}
}
return type_morespecific_(a, (jl_value_t*)((jl_tvar_t*)b)->ub, 0, env);
}

if (jl_is_unionall(a)) {
jl_unionall_t *ua = (jl_unionall_t*)a;
jl_typeenv_t newenv = { ua->var, 0x0, env };
newenv.val = (jl_value_t*)(intptr_t)count_occurs(ua->body, ua->var);
return type_morespecific_(ua->body, b, invariant, &newenv);
}
if (jl_is_unionall(b)) {
jl_unionall_t *ub = (jl_unionall_t*)b;
jl_typeenv_t newenv = { ub->var, 0x0, env };
newenv.val = (jl_value_t*)(intptr_t)count_occurs(ub->body, ub->var);
return type_morespecific_(a, ub->body, invariant, &newenv);
}

return 0;
}

Expand Down
Loading