Skip to content

Commit 274d80e

Browse files
authored
Align module base between invalidation and edge tracking (#57625)
Our implicit edge tracking for bindings does not explicitly store any edges for bindings in the *current* module. The idea behind this is that this is a good time-space tradeoff for validation, because substantially all binding references in a module will be to its defining module, while the total number of methods within a module is limited and substantially smaller than the total number of methods in the entire system. However, we have an issue where the code that stores these edges and the invalidation code disagree on which module is the *current* one. The edge storing code was using the module in which the method was defined, while the invalidation code was using the one in which the MethodTable is defined. With these being misaligned, we can miss necessary invalidations. Both options are in principle possible, but I think the former is better, because the module in which the method is defined is also the module that we are likely to have a lot of references to (since they get referenced implicitly by just writing symbols in the code). However, this presents a problem: We don't actually have a way to iterate all the methods defined in a particular module, without just doing the brute force thing of scanning all methods and filtering. To address this, build on the deferred scanning code added in #57615 to also add any scanned modules to an explicit list in `Module`. This costs some space, but only proportional to the number of defined methods, (and thus proportional to the written source code). Note that we don't actually observe any issues in the test suite on master due to this bug. However, this is because we are grossly over-invalidating, which hides the missing invalidations from this issue (#57617).
2 parents 0c0419e + a8a1c3b commit 274d80e

14 files changed

+264
-64
lines changed

Compiler/src/abstractinterpretation.jl

+10-8
Original file line numberDiff line numberDiff line change
@@ -3633,18 +3633,20 @@ scan_partitions(query::Function, interp, g::GlobalRef, wwr::WorldWithRange) =
36333633
abstract_load_all_consistent_leaf_partitions(interp, g::GlobalRef, wwr::WorldWithRange) =
36343634
scan_leaf_partitions(abstract_eval_partition_load, interp, g, wwr)
36353635

3636+
function abstract_eval_globalref_partition(interp, binding::Core.Binding, partition::Core.BindingPartition)
3637+
# For inference purposes, we don't particularly care which global binding we end up loading, we only
3638+
# care about its type. However, we would still like to terminate the world range for the particular
3639+
# binding we end up reaching such that codegen can emit a simpler pointer load.
3640+
Pair{RTEffects, Union{Nothing, Core.Binding}}(
3641+
abstract_eval_partition_load(interp, partition),
3642+
binding_kind(partition) in (PARTITION_KIND_GLOBAL, PARTITION_KIND_DECLARED) ? binding : nothing)
3643+
end
3644+
36363645
function abstract_eval_globalref(interp, g::GlobalRef, saw_latestworld::Bool, sv::AbsIntState)
36373646
if saw_latestworld
36383647
return RTEffects(Any, Any, generic_getglobal_effects)
36393648
end
3640-
(valid_worlds, (ret, binding_if_global)) = scan_leaf_partitions(interp, g, sv.world) do interp, binding, partition
3641-
# For inference purposes, we don't particularly care which global binding we end up loading, we only
3642-
# care about its type. However, we would still like to terminate the world range for the particular
3643-
# binding we end up reaching such that codegen can emit a simpler pointer load.
3644-
Pair{RTEffects, Union{Nothing, Core.Binding}}(
3645-
abstract_eval_partition_load(interp, partition),
3646-
binding_kind(partition) in (PARTITION_KIND_GLOBAL, PARTITION_KIND_DECLARED) ? binding : nothing)
3647-
end
3649+
(valid_worlds, (ret, binding_if_global)) = scan_leaf_partitions(abstract_eval_globalref_partition, interp, g, sv.world)
36483650
update_valid_age!(sv, valid_worlds)
36493651
if ret.rt !== Union{} && ret.exct === UndefVarError && binding_if_global !== nothing && InferenceParams(interp).assume_bindings_static
36503652
if isdefined(binding_if_global, :value)

Compiler/src/utilities.jl

+19
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,25 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
129129
else
130130
c = copy(src::CodeInfo)
131131
end
132+
if (def.did_scan_source & 0x1) == 0x0
133+
# This scan must happen:
134+
# 1. After method definition
135+
# 2. Before any code instances that may have relied on information
136+
# from implicit GlobalRefs for this method are added to the cache
137+
# 3. Preferably while the IR is already uncompressed
138+
# 4. As late as possible, as early adding of the backedges may cause
139+
# spurious invalidations.
140+
#
141+
# At the moment we do so here, because
142+
# 1. It's reasonably late
143+
# 2. It has easy access to the uncompressed IR
144+
# 3. We necessarily pass through here before relying on any
145+
# information obtained from implicit GlobalRefs.
146+
#
147+
# However, the exact placement of this scan is not as important as
148+
# long as the above conditions are met.
149+
ccall(:jl_scan_method_source_now, Cvoid, (Any, Any), def, c)
150+
end
132151
end
133152
if c isa CodeInfo
134153
c.parent = mi

base/expr.jl

+4
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,10 @@ function make_atomic(order, ex)
13521352
op = :+
13531353
elseif ex.head === :(-=)
13541354
op = :-
1355+
elseif ex.head === :(|=)
1356+
op = :|
1357+
elseif ex.head === :(&=)
1358+
op = :&
13551359
elseif @isdefined string
13561360
shead = string(ex.head)
13571361
if endswith(shead, '=')

base/invalidation.jl

+38-21
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,34 @@ function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalid
113113
end
114114
end
115115

116-
function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Union{Core.BindingPartition, Nothing}, new_max_world::UInt)
116+
export_affecting_partition_flags(bpart::Core.BindingPartition) =
117+
((bpart.kind & PARTITION_MASK_KIND) == PARTITION_KIND_GUARD,
118+
(bpart.kind & PARTITION_FLAG_EXPORTED) != 0,
119+
(bpart.kind & PARTITION_FLAG_DEPRECATED) != 0)
120+
121+
function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt)
117122
gr = b.globalref
118-
if !is_some_guard(binding_kind(invalidated_bpart))
119-
# TODO: We may want to invalidate for these anyway, since they have performance implications
120-
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
121-
for method in MethodList(mt)
123+
124+
(_, (ib, ibpart)) = Compiler.walk_binding_partition(b, invalidated_bpart, new_max_world)
125+
(_, (nb, nbpart)) = Compiler.walk_binding_partition(b, new_bpart, new_max_world+1)
126+
127+
# abstract_eval_globalref_partition is the maximum amount of information that inference
128+
# reads from a binding partition. If this information does not change - we do not need to
129+
# invalidate any code that inference created, because we know that the result will not change.
130+
need_to_invalidate_code =
131+
Compiler.abstract_eval_globalref_partition(nothing, ib, ibpart) !==
132+
Compiler.abstract_eval_globalref_partition(nothing, nb, nbpart)
133+
134+
need_to_invalidate_export = export_affecting_partition_flags(invalidated_bpart) !==
135+
export_affecting_partition_flags(new_bpart)
136+
137+
if need_to_invalidate_code
138+
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
139+
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
140+
for i = 1:nmethods
141+
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
122142
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
123143
end
124-
return true
125144
end
126145
if isdefined(b, :backedges)
127146
for edge in b.backedges
@@ -133,45 +152,43 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
133152
latest_bpart.max_world == typemax(UInt) || continue
134153
is_some_imported(binding_kind(latest_bpart)) || continue
135154
partition_restriction(latest_bpart) === b || continue
136-
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
155+
invalidate_code_for_globalref!(edge, latest_bpart, latest_bpart, new_max_world)
137156
else
138157
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
139158
end
140159
end
141160
end
142161
end
143-
if (invalidated_bpart.kind & PARTITION_FLAG_EXPORTED != 0) || (new_bpart !== nothing && (new_bpart.kind & PARTITION_FLAG_EXPORTED != 0))
162+
163+
if need_to_invalidate_code || need_to_invalidate_export
144164
# This binding was exported - we need to check all modules that `using` us to see if they
145165
# have a binding that is affected by this change.
146166
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
147167
if usings_backedges !== nothing
148-
for user in usings_backedges::Vector{Any}
168+
for user::Module in usings_backedges::Vector{Any}
149169
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
150170
user_binding === nothing && continue
151171
isdefined(user_binding, :partitions) || continue
152172
latest_bpart = user_binding.partitions
153173
latest_bpart.max_world == typemax(UInt) || continue
154174
binding_kind(latest_bpart) in (PARTITION_KIND_IMPLICIT, PARTITION_KIND_FAILED, PARTITION_KIND_GUARD) || continue
155-
@atomic :release latest_bpart.max_world = new_max_world
156-
invalidate_code_for_globalref!(convert(Core.Binding, user_binding), latest_bpart, nothing, new_max_world)
175+
new_bpart = need_to_invalidate_export ?
176+
ccall(:jl_maybe_reresolve_implicit, Any, (Any, Any, Csize_t), user_binding, latest_bpart, new_max_world) :
177+
latest_bpart
178+
if need_to_invalidate_code || new_bpart !== latest_bpart
179+
invalidate_code_for_globalref!(convert(Core.Binding, user_binding), latest_bpart, new_bpart, new_max_world)
180+
end
157181
end
158182
end
159183
end
160184
end
161185
invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt) =
162186
invalidate_code_for_globalref!(convert(Core.Binding, gr), invalidated_bpart, new_bpart, new_max_world)
163187

164-
gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod
165-
166-
# N.B.: This needs to match jl_maybe_add_binding_backedge
167188
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
168-
method = isa(edge, Method) ? edge : edge.def.def::Method
169-
gr_needs_backedge_in_module(b.globalref, method.module) || return
170-
if !isdefined(b, :backedges)
171-
b.backedges = Any[]
172-
end
173-
!isempty(b.backedges) && b.backedges[end] === edge && return
174-
push!(b.backedges, edge)
189+
meth = isa(edge, Method) ? edge : Compiler.get_ci_mi(edge).def
190+
ccall(:jl_maybe_add_binding_backedge, Cint, (Any, Any, Any), b, edge, meth)
191+
return nothing
175192
end
176193

177194
function binding_was_invalidated(b::Core.Binding)

base/runtime_internals.jl

+2
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ const PARTITION_FLAG_DEPWARN = 0x40
216216
const PARTITION_MASK_KIND = 0x0f
217217
const PARTITION_MASK_FLAG = 0xf0
218218

219+
const BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8
220+
219221
is_defined_const_binding(kind::UInt8) = (kind == PARTITION_KIND_CONST || kind == PARTITION_KIND_CONST_IMPORT || kind == PARTITION_KIND_BACKDATED_CONST)
220222
is_some_const_binding(kind::UInt8) = (is_defined_const_binding(kind) || kind == PARTITION_KIND_UNDEF_CONST)
221223
is_some_imported(kind::UInt8) = (kind == PARTITION_KIND_IMPLICIT || kind == PARTITION_KIND_EXPLICIT || kind == PARTITION_KIND_IMPORTED)

src/gc-stock.c

+3
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
21472147
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
21482148
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
21492149
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
2150+
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
2151+
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
2152+
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
21502153
size_t nusings = module_usings_length(mb_parent);
21512154
if (nusings > 0) {
21522155
// this is only necessary because bindings for "using" modules

src/gf.c

+37-12
Original file line numberDiff line numberDiff line change
@@ -1839,7 +1839,7 @@ JL_DLLEXPORT jl_value_t *jl_debug_method_invalidation(int state)
18391839
return jl_nothing;
18401840
}
18411841

1842-
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth);
1842+
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, jl_code_instance_t *replaced_ci, size_t max_world, int depth);
18431843

18441844
// recursively invalidate cached methods that had an edge to a replaced method
18451845
static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world, int depth)
@@ -1858,13 +1858,15 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
18581858
if (!jl_is_method(replaced_mi->def.method))
18591859
return; // shouldn't happen, but better to be safe
18601860
JL_LOCK(&replaced_mi->def.method->writelock);
1861-
if (jl_atomic_load_relaxed(&replaced->max_world) == ~(size_t)0) {
1861+
size_t replacedmaxworld = jl_atomic_load_relaxed(&replaced->max_world);
1862+
if (replacedmaxworld == ~(size_t)0) {
18621863
assert(jl_atomic_load_relaxed(&replaced->min_world) - 1 <= max_world && "attempting to set illogical world constraints (probable race condition)");
18631864
jl_atomic_store_release(&replaced->max_world, max_world);
1865+
// recurse to all backedges to update their valid range also
1866+
_invalidate_backedges(replaced_mi, replaced, max_world, depth + 1);
1867+
} else {
1868+
assert(jl_atomic_load_relaxed(&replaced->max_world) <= max_world);
18641869
}
1865-
assert(jl_atomic_load_relaxed(&replaced->max_world) <= max_world);
1866-
// recurse to all backedges to update their valid range also
1867-
_invalidate_backedges(replaced_mi, max_world, depth + 1);
18681870
JL_UNLOCK(&replaced_mi->def.method->writelock);
18691871
}
18701872

@@ -1873,19 +1875,42 @@ JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size
18731875
invalidate_code_instance(replaced, max_world, 1);
18741876
}
18751877

1876-
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
1878+
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, jl_code_instance_t *replaced_ci, size_t max_world, int depth) {
18771879
jl_array_t *backedges = replaced_mi->backedges;
18781880
if (backedges) {
18791881
// invalidate callers (if any)
18801882
replaced_mi->backedges = NULL;
18811883
JL_GC_PUSH1(&backedges);
18821884
size_t i = 0, l = jl_array_nrows(backedges);
1885+
size_t ins = 0;
18831886
jl_code_instance_t *replaced;
18841887
while (i < l) {
1885-
i = get_next_edge(backedges, i, NULL, &replaced);
1888+
jl_value_t *invokesig = NULL;
1889+
i = get_next_edge(backedges, i, &invokesig, &replaced);
18861890
JL_GC_PROMISE_ROOTED(replaced); // propagated by get_next_edge from backedges
1891+
if (replaced_ci) {
1892+
// If we're invalidating a particular codeinstance, only invalidate
1893+
// this backedge it actually has an edge for our codeinstance.
1894+
jl_svec_t *edges = jl_atomic_load_relaxed(&replaced->edges);
1895+
for (size_t j = 0; j < jl_svec_len(edges); ++j) {
1896+
jl_value_t *edge = jl_svecref(edges, j);
1897+
if (edge == (jl_value_t*)replaced_mi || edge == (jl_value_t*)replaced_ci)
1898+
goto found;
1899+
}
1900+
// Keep this entry in the backedge list, but compact it
1901+
ins = set_next_edge(backedges, ins, invokesig, replaced);
1902+
continue;
1903+
found:;
1904+
}
18871905
invalidate_code_instance(replaced, max_world, depth);
18881906
}
1907+
if (replaced_ci && ins != 0) {
1908+
jl_array_del_end(backedges, l - ins);
1909+
// If we're only invalidating one ci, we don't know which ci any particular
1910+
// backedge was for, so we can't delete them. Put them back.
1911+
replaced_mi->backedges = backedges;
1912+
jl_gc_wb(replaced_mi, backedges);
1913+
}
18891914
JL_GC_POP();
18901915
}
18911916
}
@@ -1894,7 +1919,7 @@ static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_
18941919
static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, const char *why)
18951920
{
18961921
JL_LOCK(&replaced_mi->def.method->writelock);
1897-
_invalidate_backedges(replaced_mi, max_world, 1);
1922+
_invalidate_backedges(replaced_mi, NULL, max_world, 1);
18981923
JL_UNLOCK(&replaced_mi->def.method->writelock);
18991924
if (why && _jl_debug_method_invalidation) {
19001925
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)replaced_mi);
@@ -1928,8 +1953,8 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
19281953
size_t i = 0, l = jl_array_nrows(callee->backedges);
19291954
for (i = 0; i < l; i++) {
19301955
// optimized version of while (i < l) i = get_next_edge(callee->backedges, i, &invokeTypes, &mi);
1931-
jl_value_t *mi = jl_array_ptr_ref(callee->backedges, i);
1932-
if (mi != (jl_value_t*)caller)
1956+
jl_value_t *ciedge = jl_array_ptr_ref(callee->backedges, i);
1957+
if (ciedge != (jl_value_t*)caller)
19331958
continue;
19341959
jl_value_t *invokeTypes = i > 0 ? jl_array_ptr_ref(callee->backedges, i - 1) : NULL;
19351960
if (invokeTypes && jl_is_method_instance(invokeTypes))
@@ -2372,7 +2397,7 @@ void jl_method_table_activate(jl_methtable_t *mt, jl_typemap_entry_t *newentry)
23722397
continue;
23732398
loctag = jl_atomic_load_relaxed(&m->specializations); // use loctag for a gcroot
23742399
_Atomic(jl_method_instance_t*) *data;
2375-
size_t i, l;
2400+
size_t l;
23762401
if (jl_is_svec(loctag)) {
23772402
data = (_Atomic(jl_method_instance_t*)*)jl_svec_data(loctag);
23782403
l = jl_svec_len(loctag);
@@ -2382,7 +2407,7 @@ void jl_method_table_activate(jl_methtable_t *mt, jl_typemap_entry_t *newentry)
23822407
l = 1;
23832408
}
23842409
enum morespec_options ambig = morespec_unknown;
2385-
for (i = 0; i < l; i++) {
2410+
for (size_t i = 0; i < l; i++) {
23862411
jl_method_instance_t *mi = jl_atomic_load_relaxed(&data[i]);
23872412
if ((jl_value_t*)mi == jl_nothing)
23882413
continue;

src/jltypes.c

+5-3
Original file line numberDiff line numberDiff line change
@@ -3275,7 +3275,7 @@ void jl_init_types(void) JL_GC_DISABLED
32753275
jl_svec(5, jl_any_type/*jl_globalref_type*/, jl_any_type, jl_binding_partition_type,
32763276
jl_any_type, jl_uint8_type),
32773277
jl_emptysvec, 0, 1, 0);
3278-
const static uint32_t binding_atomicfields[] = { 0x0005 }; // Set fields 2, 3 as atomic
3278+
const static uint32_t binding_atomicfields[] = { 0x0016 }; // Set fields 2, 3, 5 as atomic
32793279
jl_binding_type->name->atomicfields = binding_atomicfields;
32803280
const static uint32_t binding_constfields[] = { 0x0001 }; // Set fields 1 as constant
32813281
jl_binding_type->name->constfields = binding_constfields;
@@ -3539,7 +3539,7 @@ void jl_init_types(void) JL_GC_DISABLED
35393539
jl_method_type =
35403540
jl_new_datatype(jl_symbol("Method"), core,
35413541
jl_any_type, jl_emptysvec,
3542-
jl_perm_symsvec(31,
3542+
jl_perm_symsvec(32,
35433543
"name",
35443544
"module",
35453545
"file",
@@ -3568,10 +3568,11 @@ void jl_init_types(void) JL_GC_DISABLED
35683568
"isva",
35693569
"is_for_opaque_closure",
35703570
"nospecializeinfer",
3571+
"did_scan_source",
35713572
"constprop",
35723573
"max_varargs",
35733574
"purity"),
3574-
jl_svec(31,
3575+
jl_svec(32,
35753576
jl_symbol_type,
35763577
jl_module_type,
35773578
jl_symbol_type,
@@ -3602,6 +3603,7 @@ void jl_init_types(void) JL_GC_DISABLED
36023603
jl_bool_type,
36033604
jl_uint8_type,
36043605
jl_uint8_type,
3606+
jl_uint8_type,
36053607
jl_uint16_type),
36063608
jl_emptysvec,
36073609
0, 1, 10);

0 commit comments

Comments
 (0)