Skip to content

Commit 45f9286

Browse files
committed
inference: improve management of non-type parameters
Prevent occurrence of v or Type{v} in the type-lattice, where v is not a Type (or TypeVar). Fixes #42646, and similar problems from code-reading.
1 parent 50fcb03 commit 45f9286

File tree

7 files changed

+99
-58
lines changed

7 files changed

+99
-58
lines changed

base/compiler/abstractinterpretation.jl

+13-11
Original file line numberDiff line numberDiff line change
@@ -783,26 +783,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
783783
end
784784
if isa(tti, Union)
785785
utis = uniontypes(tti)
786-
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
786+
if _any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
787787
return Any[Vararg{Any}], nothing
788788
end
789-
result = Any[rewrap_unionall(p, tti0) for p in (utis[1]::DataType).parameters]
790-
for t::DataType in utis[2:end]
791-
if length(t.parameters) != length(result)
789+
ltp = length((utis[1]::DataType).parameters)
790+
for t in utis
791+
if length((t::DataType).parameters) != ltp
792792
return Any[Vararg{Any}], nothing
793793
end
794-
for j in 1:length(t.parameters)
795-
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
794+
end
795+
result = Any[ Union{} for _ in 1:ltp ]
796+
for t in utis
797+
tps = (t::DataType).parameters
798+
_any(@nospecialize(t) -> !isa(t, Type), tps) && continue
799+
for j in 1:ltp
800+
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
796801
end
797802
end
798803
return result, nothing
799804
elseif tti0 <: Tuple
800805
if isa(tti0, DataType)
801-
if isvatuple(tti0) && length(tti0.parameters) == 1
802-
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
803-
else
804-
return Any[ p for p in tti0.parameters ], nothing
805-
end
806+
return Any[ p for p in tti0.parameters ], nothing
806807
elseif !isa(tti, DataType)
807808
return Any[Vararg{Any}], nothing
808809
else
@@ -1098,6 +1099,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
10981099
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
10991100
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
11001101
ifty = typeintersect(aty, tty_ub)
1102+
valid_as_lattice(ifty) || (ifty = Union{})
11011103
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
11021104
return Conditional(a, ifty, elty)
11031105
end

base/compiler/inferenceresult.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
6262
if !toplevel && isva
6363
if specTypes == Tuple
6464
if nargs > 1
65-
linfo_argtypes = svec(Any[Any for i = 1:(nargs - 1)]..., Tuple.parameters[1])
65+
linfo_argtypes = Any[Any for i = 1:nargs]
66+
linfo_argstypes[end] = Vararg{Any}
6667
end
6768
vargtype = Tuple
6869
else
@@ -77,9 +78,10 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
7778
end
7879
else
7980
vargtype_elements = Any[]
80-
for p in linfo_argtypes[nargs:linfo_argtypes_length]
81+
for i in nargs:linfo_argtypes_length
82+
p = linfo_argtypes[i]
8183
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
82-
push!(vargtype_elements, elim_free_typevars(rewrap(p, specTypes)))
84+
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
8385
end
8486
for i in 1:length(vargtype_elements)
8587
atyp = vargtype_elements[i]
@@ -118,7 +120,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
118120
elseif isconstType(atyp)
119121
atyp = Const(atyp.parameters[1])
120122
else
121-
atyp = elim_free_typevars(rewrap(atyp, specTypes))
123+
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
122124
end
123125
i == n && (lastatype = atyp)
124126
cache_argtypes[i] = atyp

base/compiler/tfuncs.jl

+40-27
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ add_tfunc(throw, 1, 1, (@nospecialize(x)) -> Bottom, 0)
6767
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
6868
function instanceof_tfunc(@nospecialize(t))
6969
if isa(t, Const)
70-
if isa(t.val, Type)
70+
if isa(t.val, Type) && valid_as_lattice(t.val)
7171
return t.val, true, isconcretetype(t.val), true
7272
end
7373
return Bottom, true, false, false # runtime throws on non-Type
@@ -79,6 +79,7 @@ function instanceof_tfunc(@nospecialize(t))
7979
return Bottom, true, false, false # literal Bottom or non-Type
8080
elseif isType(t)
8181
tp = t.parameters[1]
82+
valid_as_lattice(tp) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
8283
return tp, !has_free_typevars(tp), isconcretetype(tp), true
8384
elseif isa(t, UnionAll)
8485
t′ = unwrap_unionall(t)
@@ -473,7 +474,8 @@ function pointer_eltype(@nospecialize(ptr))
473474
unw = unwrap_unionall(a)
474475
if isa(unw, DataType) && unw.name === Ptr.body.name
475476
T = unw.parameters[1]
476-
T isa Type && return rewrap_unionall(T, a)
477+
valid_as_lattice(T) || return Bottom
478+
return rewrap_unionall(T, a)
477479
end
478480
end
479481
return Any
@@ -486,7 +488,8 @@ function atomic_pointermodify_tfunc(ptr, op, v, order)
486488
if isa(unw, DataType) && unw.name === Ptr.body.name
487489
T = unw.parameters[1]
488490
# note: we could sometimes refine this to a PartialStruct if we analyzed `op(T, T)::T`
489-
T isa Type && return rewrap_unionall(Pair{T, T}, a)
491+
valid_as_lattice(T) || return Bottom
492+
return rewrap_unionall(Pair{T, T}, a)
490493
end
491494
end
492495
return Pair
@@ -498,7 +501,8 @@ function atomic_pointerreplace_tfunc(ptr, x, v, success_order, failure_order)
498501
unw = unwrap_unionall(a)
499502
if isa(unw, DataType) && unw.name === Ptr.body.name
500503
T = unw.parameters[1]
501-
T isa Type && return rewrap_unionall(ccall(:jl_apply_cmpswap_type, Any, (Any,), T), a)
504+
valid_as_lattice(T) || return Bottom
505+
return rewrap_unionall(ccall(:jl_apply_cmpswap_type, Any, (Any,), T), a)
502506
end
503507
end
504508
return ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T
@@ -754,8 +758,8 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), boundscheck::
754758
s0 = widenconst(s00)
755759
s = unwrap_unionall(s0)
756760
if isa(s, Union)
757-
return getfield_nothrow(rewrap(s.a, s00), name, boundscheck) &&
758-
getfield_nothrow(rewrap(s.b, s00), name, boundscheck)
761+
return getfield_nothrow(rewrap_unionall(s.a, s00), name, boundscheck) &&
762+
getfield_nothrow(rewrap_unionall(s.b, s00), name, boundscheck)
759763
elseif isa(s, DataType)
760764
# Can't say anything about abstract types
761765
isabstracttype(s) && return false
@@ -782,8 +786,8 @@ getfield_tfunc(s00, name, order, boundscheck) = (@nospecialize; getfield_tfunc(s
782786
function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
783787
s = unwrap_unionall(s00)
784788
if isa(s, Union)
785-
return tmerge(getfield_tfunc(rewrap(s.a,s00), name),
786-
getfield_tfunc(rewrap(s.b,s00), name))
789+
return tmerge(getfield_tfunc(rewrap_unionall(s.a, s00), name),
790+
getfield_tfunc(rewrap_unionall(s.b, s00), name))
787791
elseif isa(s, Conditional)
788792
return Bottom # Bool has no fields
789793
elseif isa(s, Const) || isconstType(s)
@@ -857,9 +861,6 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
857861
end
858862
return Any
859863
end
860-
# If no value has this type, then this statement should be unreachable.
861-
# Bail quickly now.
862-
has_concrete_subtype(s) || return Union{}
863864
if s.name === _NAMEDTUPLE_NAME && !isconcretetype(s)
864865
if isa(name, Const) && isa(name.val, Symbol)
865866
if isa(s.parameters[1], Tuple)
@@ -878,7 +879,9 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
878879
return getfield_tfunc(_ts, name)
879880
end
880881
ftypes = datatype_fieldtypes(s)
881-
if isempty(ftypes)
882+
# If no value has this type, then this statement should be unreachable.
883+
# Bail quickly now.
884+
if !has_concrete_subtype(s) || isempty(ftypes)
882885
return Bottom
883886
end
884887
if isa(name, Conditional)
@@ -1072,8 +1075,8 @@ function fieldtype_tfunc(@nospecialize(s0), @nospecialize(name))
10721075

10731076
su = unwrap_unionall(s0)
10741077
if isa(su, Union)
1075-
return tmerge(fieldtype_tfunc(rewrap(su.a, s0), name),
1076-
fieldtype_tfunc(rewrap(su.b, s0), name))
1078+
return tmerge(fieldtype_tfunc(rewrap_unionall(su.a, s0), name),
1079+
fieldtype_tfunc(rewrap_unionall(su.b, s0), name))
10771080
end
10781081

10791082
s, exact = instanceof_tfunc(s0)
@@ -1085,8 +1088,8 @@ function _fieldtype_tfunc(@nospecialize(s), exact::Bool, @nospecialize(name))
10851088
exact = exact && !has_free_typevars(s)
10861089
u = unwrap_unionall(s)
10871090
if isa(u, Union)
1088-
ta0 = _fieldtype_tfunc(rewrap(u.a, s), exact, name)
1089-
tb0 = _fieldtype_tfunc(rewrap(u.b, s), exact, name)
1091+
ta0 = _fieldtype_tfunc(rewrap_unionall(u.a, s), exact, name)
1092+
tb0 = _fieldtype_tfunc(rewrap_unionall(u.b, s), exact, name)
10901093
ta0 tb0 && return tb0
10911094
tb0 ta0 && return ta0
10921095
ta, exacta, _, istypea = instanceof_tfunc(ta0)
@@ -1296,7 +1299,11 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
12961299
end
12971300
end
12981301
end
1299-
largs == 1 && return isa(args[1], Type) ? typeintersect(args[1], Type) : Type
1302+
if largs == 1 # Union{T} --> T
1303+
u1 = typeintersect(widenconst(args[1]), Type)
1304+
valid_as_lattice(u1) || return Bottom
1305+
return u1
1306+
end
13001307
hasnonType && return Type
13011308
ty = Union{}
13021309
allconst = true
@@ -1471,21 +1478,26 @@ end
14711478

14721479
function arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize i...)
14731480
a = widenconst(a)
1474-
if a <: Array
1475-
if isa(a, DataType) && isa(a.parameters[1], Type)
1476-
return a.parameters[1]
1477-
elseif isa(a, UnionAll) && !has_free_typevars(a)
1478-
unw = unwrap_unionall(a)
1479-
if isa(unw, DataType)
1480-
return rewrap_unionall(unw.parameters[1], a)
1481-
end
1481+
if !has_free_typevars(a) && a <: Array
1482+
a0 = a
1483+
if isa(a, UnionAll)
1484+
a = unwrap_unionall(a0)
1485+
end
1486+
if isa(a, DataType)
1487+
T = a.parameters[1]
1488+
valid_as_lattice(T) || return Bottom
1489+
return rewrap_unionall(T, a0)
14821490
end
14831491
end
14841492
return Any
14851493
end
14861494
add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
14871495
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
1488-
add_tfunc(arrayset, 4, INT_INF, (@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)->a, 20)
1496+
function arrayset_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)
1497+
# TODO: we could check that the type-intersect of arrayref_tfunc and v is non-empty or always throws
1498+
return a
1499+
end
1500+
add_tfunc(arrayset, 4, INT_INF, arrayset_tfunc, 20)
14891501

14901502
function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
14911503
@nospecialize(lb), @nospecialize(ub), @nospecialize(source), env::Vector{Any},
@@ -1508,6 +1520,7 @@ function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
15081520
return PartialOpaque(t, tuple_tfunc(env), isva.val, linfo, source.val)
15091521
end
15101522

1523+
# whether getindex for the elements can potentially throw UndefRef
15111524
function array_type_undefable(@nospecialize(a))
15121525
if isa(a, Union)
15131526
return array_type_undefable(a.a) || array_type_undefable(a.b)
@@ -1550,7 +1563,7 @@ function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecializ
15501563
# Check that we can determine the element type
15511564
(isa(a, DataType) && isa(a.parameters[1], Type)) || return false
15521565
# Check that the element type is compatible with the element we're assigning
1553-
(argtypes[3] a.parameters[1]::Type) || return false
1566+
(argtypes[3] a.parameters[1]) || return false
15541567
return true
15551568
elseif f === arrayref || f === const_arrayref
15561569
return array_builtin_common_nothrow(argtypes, 3)

base/compiler/typelimits.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,7 @@ function tmeet(@nospecialize(v), @nospecialize(t))
605605
return v
606606
end
607607
ti = typeintersect(widev, t)
608-
if ti === Bottom
609-
return Bottom
610-
end
608+
valid_as_lattice(ti) || return Bottom
611609
@assert widev <: Tuple
612610
new_fields = Vector{Any}(undef, length(v.fields))
613611
for i = 1:length(new_fields)
@@ -628,5 +626,7 @@ function tmeet(@nospecialize(v), @nospecialize(t))
628626
end
629627
return v
630628
end
631-
return typeintersect(widenconst(v), t)
629+
ti = typeintersect(widenconst(v), t)
630+
valid_as_lattice(ti) || return Bottom
631+
return ti
632632
end

base/compiler/typeutils.jl

+28-11
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44
# lattice utilities #
55
#####################
66

7-
function rewrap(@nospecialize(t), @nospecialize(u))
8-
if isa(t, TypeVar) || isa(t, Type) || isvarargtype(t)
9-
return rewrap_unionall(t, u)
10-
end
11-
return t
12-
end
13-
147
isType(@nospecialize t) = isa(t, DataType) && t.name === _TYPE_NAME
158

169
# true if Type{T} is inlineable as constant T
@@ -42,8 +35,6 @@ end
4235

4336
has_const_info(@nospecialize x) = (!isa(x, Type) && !isvarargtype(x)) || isType(x)
4437

45-
has_concrete_subtype(d::DataType) = d.flags & 0x20 == 0x20
46-
4738
# Subtyping currently intentionally answers certain queries incorrectly for kind types. For
4839
# some of these queries, this check can be used to somewhat protect against making incorrect
4940
# decisions based on incorrect subtyping. Note that this check, itself, is broken for
@@ -89,6 +80,30 @@ function datatype_min_ninitialized(t::DataType)
8980
return length(t.name.names) - t.name.n_uninitialized
9081
end
9182

83+
has_concrete_subtype(d::DataType) = d.flags & 0x20 == 0x20 # n.b. often computed only after setting the type and layout fields
84+
85+
# determine whether x is a valid lattice element tag
86+
# For example, Type{v} is not valid if v is a value
87+
# Accepts TypeVars also, since it assumes the user will rewrap it correctly
88+
function valid_as_lattice(@nospecialize(x))
89+
x === Bottom && false
90+
x isa TypeVar && return valid_as_lattice(x.ub)
91+
x isa UnionAll && (x = unwrap_unionall(x))
92+
if x isa Union
93+
# the Union constructor ensures this (and we'll recheck after
94+
# operations that might remove the Union itself)
95+
return true
96+
end
97+
if x isa DataType
98+
if isType(x)
99+
p = x.parameters[1]
100+
p isa Type || p isa TypeVar || return false
101+
end
102+
return true
103+
end
104+
return false
105+
end
106+
92107
# test if non-Type, non-TypeVar `x` can be used to parameterize a type
93108
function valid_tparam(@nospecialize(x))
94109
if isa(x, Tuple)
@@ -119,8 +134,10 @@ function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::I
119134
end
120135
ua = unwrap_unionall(a)
121136
if isa(ua, Union)
122-
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
123-
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
137+
uua = typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING)
138+
uub = typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)
139+
return Union{valid_as_lattice(uua) ? uua : Union{},
140+
valid_as_lattice(uub) ? uub : Union{}}
124141
elseif a isa DataType
125142
ub = unwrap_unionall(b)
126143
if ub isa DataType

src/jltypes.c

+5-1
Original file line numberDiff line numberDiff line change
@@ -1198,8 +1198,12 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt, int cacheable)
11981198
dt->has_concrete_subtype = 0;
11991199
}
12001200
}
1201-
if (dt->name == jl_type_typename)
1201+
if (dt->name == jl_type_typename) {
12021202
cacheable = 0; // the cache for Type ignores parameter normalization, so it can't be used as a regular hash
1203+
jl_value_t *p = jl_tparam(dt, 0);
1204+
if (!jl_is_type(p) && !jl_is_typevar(p)) // Type{v} has no subtypes, if v is not a Type
1205+
dt->has_concrete_subtype = 0;
1206+
}
12031207
dt->hash = typekey_hash(dt->name, jl_svec_data(dt->parameters), l, cacheable);
12041208
dt->cached_by_hash = cacheable ? (typekey_hash(dt->name, jl_svec_data(dt->parameters), l, 0) != 0) : (dt->hash != 0);
12051209
}

test/compiler/inference.jl

+3
Original file line numberDiff line numberDiff line change
@@ -3583,3 +3583,6 @@ let
35833583
@test argtypes[10] == Any
35843584
@test argtypes[11] == Tuple{Integer,Integer}
35853585
end
3586+
3587+
# issue #42646
3588+
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw

0 commit comments

Comments
 (0)