Skip to content

Commit 8960c6c

Browse files
committed
fix some bugs in diagonal subtyping
fixes #31824, fixes #24166
1 parent 6421def commit 8960c6c

File tree

5 files changed

+105
-18
lines changed

5 files changed

+105
-18
lines changed

base/essentials.jl

+12-4
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,18 @@ convert(::Type{T}, x::T) where {T<:AtLeast1} = x
306306
convert(::Type{T}, x::AtLeast1) where {T<:AtLeast1} =
307307
(convert(tuple_type_head(T), x[1]), convert(tuple_type_tail(T), tail(x))...)
308308

309-
# converting to Vararg tuple types
310-
convert(::Type{Tuple{Vararg{V}}}, x::Tuple{Vararg{V}}) where {V} = x
311-
convert(T::Type{Tuple{Vararg{V}}}, x::Tuple) where {V} =
312-
(convert(tuple_type_head(T), x[1]), convert(T, tail(x))...)
309+
# converting to other tuple types, including Vararg tuple types
310+
_bad_tuple_conversion(T, x) = (@_noinline_meta; throw(MethodError(convert, (T, x))))
311+
function convert(::Type{T}, x::AtLeast1) where {T<:Tuple}
312+
if x isa T
313+
return x
314+
end
315+
y = (convert(tuple_type_head(T), x[1]), convert(tuple_type_tail(T), tail(x))...)
316+
if !(y isa T)
317+
_bad_tuple_conversion(T, x)
318+
end
319+
return y
320+
end
313321

314322
# used for splatting in `new`
315323
convert_prefix(::Type{Tuple{}}, x::Tuple) = x

src/subtype.c

+78-7
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ typedef struct jl_varbinding_t {
6464
int8_t occurs_inv; // occurs in invariant position
6565
int8_t occurs_cov; // # of occurrences in covariant position
6666
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
67+
int8_t upper_bounded; // var upper bound has been constrained
6768
// in covariant position, we need to try constraining a variable in different ways:
6869
// 0 - unconstrained
6970
// 1 - less than
@@ -145,7 +146,7 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
145146
v = v->prev;
146147
}
147148
*root = (jl_value_t*)jl_alloc_svec(len*3);
148-
se->buf = (int8_t*)(len ? malloc(len*2) : NULL);
149+
se->buf = (int8_t*)(len ? malloc(len*3) : NULL);
149150
#ifdef __clang_analyzer__
150151
if (len)
151152
memset(se->buf, 0, len*2);
@@ -157,6 +158,7 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
157158
jl_svecset(*root, i++, (jl_value_t*)v->innervars);
158159
se->buf[j++] = v->occurs_inv;
159160
se->buf[j++] = v->occurs_cov;
161+
se->buf[j++] = v->upper_bounded;
160162
v = v->prev;
161163
}
162164
se->rdepth = e->Runions.depth;
@@ -176,13 +178,46 @@ static void restore_env(jl_stenv_t *e, jl_value_t *root, jl_savedenv_t *se) JL_N
176178
assert(se->buf);
177179
v->occurs_inv = se->buf[j++];
178180
v->occurs_cov = se->buf[j++];
181+
v->upper_bounded = se->buf[j++];
179182
v = v->prev;
180183
}
181184
e->Runions.depth = se->rdepth;
182185
if (e->envout && e->envidx < e->envsz)
183186
memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*));
184187
}
185188

189+
// restore just occurs_inv and occurs_cov from `se` back to `e`
190+
static void restore_var_counts(jl_stenv_t *e, jl_savedenv_t *se) JL_NOTSAFEPOINT
191+
{
192+
jl_varbinding_t *v = e->vars;
193+
int j = 0;
194+
while (v != NULL) {
195+
assert(se->buf);
196+
v->occurs_inv = se->buf[j++];
197+
v->occurs_cov = se->buf[j++];
198+
j++;
199+
v = v->prev;
200+
}
201+
}
202+
203+
// compute the maximum of the occurence counts in `e` and `se`, storing them in `se`
204+
static void max_var_counts(jl_stenv_t *e, jl_savedenv_t *se) JL_NOTSAFEPOINT
205+
{
206+
jl_varbinding_t *v = e->vars;
207+
int j = 0;
208+
while (v != NULL) {
209+
assert(se->buf);
210+
if (v->occurs_inv > se->buf[j])
211+
se->buf[j] = v->occurs_inv;
212+
j++;
213+
if (v->occurs_cov > se->buf[j])
214+
se->buf[j] = v->occurs_cov;
215+
j++;
216+
j++;
217+
v = v->prev;
218+
}
219+
}
220+
186221
// type utilities
187222

188223
// quickly test that two types are identical
@@ -542,6 +577,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
542577
else {
543578
bb->ub = simple_meet(bb->ub, a);
544579
}
580+
bb->upper_bounded = 1;
545581
assert(bb->ub != (jl_value_t*)b);
546582
if (jl_is_typevar(a)) {
547583
jl_varbinding_t *aa = lookup(e, (jl_tvar_t*)a);
@@ -657,7 +693,7 @@ static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8
657693
}
658694
btemp = btemp->prev;
659695
}
660-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
696+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
661697
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
662698
e->vars = &vb;
663699
int ans;
@@ -671,7 +707,9 @@ static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8
671707
// fill variable values into `envout` up to `envsz`
672708
if (e->envidx < e->envsz) {
673709
jl_value_t *val;
674-
if (!vb.occurs_inv && vb.lb != jl_bottom_type)
710+
if (vb.lb == vb.ub && vb.upper_bounded)
711+
val = vb.lb;
712+
else if (!vb.occurs_inv && vb.lb != jl_bottom_type)
675713
val = is_leaf_bound(vb.lb) ? vb.lb : (jl_value_t*)jl_new_typevar(u->var->name, jl_bottom_type, vb.lb);
676714
else if (vb.lb == vb.ub)
677715
val = vb.lb;
@@ -720,7 +758,7 @@ static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8
720758
else if (!is_leaf_bound(vb.lb)) {
721759
ans = 0;
722760
}
723-
if (ans) {
761+
if (ans && vb.lb != vb.ub) {
724762
// if we occur as another var's lower bound, record the fact that we
725763
// were concrete so that subtype can return true for that var.
726764
/*
@@ -731,6 +769,17 @@ static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8
731769
btemp = btemp->prev;
732770
}
733771
*/
772+
// a diagonal var cannot be >: another diagonal var at a different invariant depth, e.g.
773+
// Ref{Tuple{T,T} where T} !<: Ref{Tuple{T,T}} where T
774+
btemp = vb.prev;
775+
while (btemp != NULL) {
776+
if (btemp->lb == (jl_value_t*)u->var && btemp->depth0 != vb.depth0 &&
777+
(btemp->concrete || (btemp->occurs_cov > 1 && btemp->occurs_inv == 0))) {
778+
ans = 0;
779+
break;
780+
}
781+
btemp = btemp->prev;
782+
}
734783
}
735784
}
736785

@@ -811,7 +860,7 @@ static int subtype_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_stenv_t *e, in
811860
else if ((vvy && ly > lx+1) || (!vvy && lx != ly)) {
812861
return 0;
813862
}
814-
param = (param == 0 ? 1 : param);
863+
param = 1;
815864
jl_value_t *lastx=NULL, *lasty=NULL;
816865
while (i < lx) {
817866
jl_value_t *xi = jl_tparam(xd, i);
@@ -1067,6 +1116,12 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
10671116
{
10681117
if (obviously_egal(x, y)) return 1;
10691118

1119+
jl_savedenv_t se; // original env
1120+
jl_savedenv_t me; // for accumulating maximum var counts
1121+
jl_value_t *saved=NULL;
1122+
save_env(e, &saved, &se);
1123+
save_env(e, &saved, &me);
1124+
10701125
jl_unionstate_t oldLunions = e->Lunions;
10711126
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
10721127
int sub;
@@ -1096,11 +1151,27 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
10961151
statestack_set(&e->Lunions, i, 0);
10971152
lastset = set - 1;
10981153
statestack_set(&e->Lunions, lastset, 1);
1154+
// take the maximum of var counts over all choices, to identify
1155+
// diagonal variables better.
1156+
max_var_counts(e, &me);
1157+
restore_var_counts(e, &se);
10991158
}
11001159
}
11011160

11021161
e->Lunions = oldLunions;
1103-
return sub && subtype(y, x, e, 0);
1162+
if (sub) {
1163+
// avoid double-counting variables when we check subtype in both directions.
1164+
// e.g. in `Ref{Tuple{T}}` the `T` occurs once even though we recursively
1165+
// call `subtype` on it twice.
1166+
max_var_counts(e, &me);
1167+
restore_var_counts(e, &se);
1168+
sub = subtype(y, x, e, 2);
1169+
max_var_counts(e, &me);
1170+
restore_var_counts(e, &me);
1171+
}
1172+
free(se.buf);
1173+
free(me.buf);
1174+
return sub;
11041175
}
11051176

11061177
static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
@@ -1961,7 +2032,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
19612032
{
19622033
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
19632034
jl_savedenv_t se, se2;
1964-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
2035+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
19652036
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
19662037
save_env(e, &save, &se);
19672038
res = intersect_unionall_(t, u, e, R, param, &vb);

test/ambiguous.jl

-4
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,6 @@ end
269269
@test_broken need_to_handle_undef_sparam == Set()
270270
pop!(need_to_handle_undef_sparam, which(Core.Compiler._cat, Tuple{Any, AbstractArray}))
271271
pop!(need_to_handle_undef_sparam, first(methods(Core.Compiler.same_names)))
272-
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
273-
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
274272
@test need_to_handle_undef_sparam == Set()
275273
end
276274
let need_to_handle_undef_sparam =
@@ -292,8 +290,6 @@ end
292290
pop!(need_to_handle_undef_sparam, which(Base.oneunit, Tuple{Type{Union{Missing, T}} where T}))
293291
pop!(need_to_handle_undef_sparam, which(Base.convert, (Type{Union{Some{T}, Nothing}} where T, Some)))
294292
pop!(need_to_handle_undef_sparam, which(Base.convert, (Type{Union{T, Nothing}} where T, Some)))
295-
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
296-
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
297293
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T))
298294
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,T}},Union{Missing,T}} where T))
299295
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,Nothing,T}},Union{Missing,Nothing,T}} where T))

test/subtype.jl

+8
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ function test_diagonal()
134134
@test issub(Tuple{Tuple{T, T}} where T>:Int, Tuple{Tuple{T, T} where T>:Int})
135135
@test issub(Vector{Tuple{T, T} where Number<:T<:Number},
136136
Vector{Tuple{Number, Number}})
137+
138+
@test !issub(Type{Tuple{T,Any} where T}, Type{Tuple{T,T}} where T)
139+
@test !issub(Type{Tuple{T,Any,T} where T}, Type{Tuple{T,T,T}} where T)
140+
@test issub(Type{Tuple{T} where T}, Type{Tuple{T}} where T)
141+
@test !issub(Type{Tuple{T,T} where T}, Type{Tuple{T,T}} where T)
142+
@test !issub(Type{Tuple{T,T,T} where T}, Type{Tuple{T,T,T}} where T)
143+
@test isequal_type(Ref{Tuple{T, T} where Int<:T<:Int},
144+
Ref{Tuple{S, S}} where Int<:S<:Int)
137145
end
138146

139147
# level 3: UnionAll

test/tuple.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ end
2828
@test convert(Tuple{Int, Int, Float64}, (1, 2, 3)) === (1, 2, 3.0)
2929

3030
@test convert(Tuple{Float64, Int, UInt8}, (1.0, 2, 0x3)) === (1.0, 2, 0x3)
31-
@test convert(NTuple, (1.0, 2, 0x3)) === (1.0, 2, 0x3)
31+
@test convert(Tuple{Vararg{Real}}, (1.0, 2, 0x3)) === (1.0, 2, 0x3)
32+
@test convert(Tuple{Vararg{Integer}}, (1.0, 2, 0x3)) === (1, 2, 0x3)
3233
@test convert(Tuple{Vararg{Int}}, (1.0, 2, 0x3)) === (1, 2, 3)
3334
@test convert(Tuple{Int, Vararg{Int}}, (1.0, 2, 0x3)) === (1, 2, 3)
34-
@test convert(Tuple{Vararg{T}} where T<:Integer, (1.0, 2, 0x3)) === (1, 2, 0x3)
35-
@test convert(Tuple{T, Vararg{T}} where T<:Integer, (1.0, 2, 0x3)) === (1, 2, 0x3)
3635
@test convert(NTuple{3, Int}, (1.0, 2, 0x3)) === (1, 2, 3)
3736
@test convert(Tuple{Int, Int, Float64}, (1.0, 2, 0x3)) === (1, 2, 3.0)
3837

@@ -53,6 +52,11 @@ end
5352
@test_throws MethodError convert(Tuple{Int, Int, Int}, (1, 2))
5453
# issue #26589
5554
@test_throws MethodError convert(NTuple{4}, (1.0,2.0,3.0,4.0,5.0))
55+
# issue #31824
56+
# there is no generic way to convert an arbitrary tuple to a homogeneous tuple
57+
@test_throws MethodError convert(NTuple, (1, 1.0))
58+
@test_throws MethodError convert(Tuple{Vararg{T}} where T<:Integer, (1.0, 2, 0x3)) === (1, 2, 0x3)
59+
@test_throws MethodError convert(Tuple{T, Vararg{T}} where T<:Integer, (1.0, 2, 0x3)) === (1, 2, 0x3)
5660

5761
# PR #15516
5862
@test Tuple{Char,Char}("za") === ('z','a')

0 commit comments

Comments
 (0)