Skip to content

Commit 3a412b7

Browse files
committed
use type signature instead of Method for inference-spoofing CodeInfo field [ci skip]
1 parent 3c62049 commit 3a412b7

File tree

6 files changed

+28
-19
lines changed

6 files changed

+28
-19
lines changed

base/inference.jl

+20-11
Original file line numberDiff line numberDiff line change
@@ -2085,19 +2085,28 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
20852085
return result
20862086
end
20872087

2088-
function method_for_inference_heuristics(infstate::InferenceState)
2089-
m = infstate.src.method_for_inference_heuristics
2090-
return isa(m, Method) ? m : infstate.linfo.def
2088+
function method_for_inference_heuristics(cinfo, default::Method)::Method
2089+
if isa(cinfo, CodeInfo)
2090+
# appropriate format for `sig` is svec(ftype, argtypes, world)
2091+
sig = cinfo.signature_for_inference_heuristics
2092+
if isa(sig, SimpleVector) && length(sig) == 3
2093+
methods = _methods(sig[1], sig[2], -1, sig[3])
2094+
if length(methods) == 1
2095+
_, _, m = methods[]
2096+
if isa(m, Method)
2097+
return m
2098+
end
2099+
end
2100+
end
2101+
end
2102+
return default
20912103
end
20922104

2093-
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
2105+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)::Method
20942106
if isdefined(method, :generator) && method.generator.expand_early
20952107
method_instance = code_for_method(method, sig, sparams, world, false)
20962108
if isa(method_instance, MethodInstance)
2097-
cinfo = get_staged(method_instance)
2098-
if isa(cinfo, CodeInfo) && isa(cinfo.method_for_inference_heuristics, Method)
2099-
return cinfo.method_for_inference_heuristics
2100-
end
2109+
return method_for_inference_heuristics(get_staged(method_instance), method)
21012110
end
21022111
end
21032112
return method
@@ -2122,7 +2131,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
21222131
edgecycle = true
21232132
break
21242133
end
2125-
working_method = method_for_inference_heuristics(infstate)
2134+
working_method = method_for_inference_heuristics(infstate.src, infstate.linfo.def)
21262135
if checked_method === working_method
21272136
if topmost === nothing
21282137
# inspect the parent of this edge,
@@ -2142,7 +2151,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
21422151
# then check the parent link
21432152
if topmost === nothing && parent !== nothing
21442153
parent = parent::InferenceState
2145-
parent_method = method_for_inference_heuristics(parent)
2154+
parent_method = method_for_inference_heuristics(parent.src, parent.linfo.def)
21462155
if parent.cached && parent_method === working_method
21472156
topmost = infstate
21482157
edgecycle = true
@@ -3404,7 +3413,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
34043413
method = linfo.def::Method
34053414
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
34063415
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
3407-
tree.method_for_inference_heuristics = nothing
3416+
tree.signature_for_inference_heuristics = nothing
34083417
tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ]
34093418
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
34103419
tree.slottypes = nothing

src/jltypes.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -2046,7 +2046,7 @@ void jl_init_types(void)
20462046
jl_any_type, jl_emptysvec,
20472047
jl_perm_symsvec(10,
20482048
"code",
2049-
"method_for_inference_heuristics"
2049+
"signature_for_inference_heuristics"
20502050
"slottypes",
20512051
"ssavaluetypes",
20522052
"slotflags",

src/julia.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ typedef struct _jl_llvm_functions_t {
229229
// This type describes a single function body
230230
typedef struct _jl_code_info_t {
231231
jl_array_t *code; // Any array of statements
232-
jl_value_t *method_for_inference_heuristics; // optional method used during inference
232+
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
233233
jl_value_t *slottypes; // types of variable slots (or `nothing`)
234234
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
235235
jl_array_t *slotflags; // local var bit flags

src/method.c

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
187187
jl_array_del_end(meta, na - ins);
188188
}
189189
}
190-
li->method_for_inference_heuristics = jl_nothing;
190+
li->signature_for_inference_heuristics = jl_nothing;
191191
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
192192
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
193193
size_t nslots = jl_array_len(vis);
@@ -256,7 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
256256
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
257257
jl_code_info_type);
258258
src->code = NULL;
259-
src->method_for_inference_heuristics = NULL;
259+
src->signature_for_inference_heuristics = NULL;
260260
src->slotnames = NULL;
261261
src->slotflags = NULL;
262262
src->slottypes = NULL;

src/toplevel.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
545545
jl_gc_wb(src, src->slotflags);
546546
src->ssavaluetypes = jl_box_long(0);
547547
jl_gc_wb(src, src->ssavaluetypes);
548-
src->method_for_inference_heuristics = jl_nothing;
548+
src->signature_for_inference_heuristics = jl_nothing;
549549

550550
JL_GC_POP();
551551
return src;

test/inference.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,7 @@ end
13241324

13251325
function f24852_gen_cinfo_inflated(X, Y, f, x, y)
13261326
method, code_info = f24852_kernel_cinfo(x, y)
1327-
code_info.method_for_inference_heuristics = method
1327+
code_info.signature_for_inference_heuristics = Core.Inference.svec(f, (x, y), typemax(UInt))
13281328
return code_info
13291329
end
13301330

@@ -1365,7 +1365,7 @@ end
13651365
x, y = rand(), rand()
13661366
result = f24852_kernel(x, y)
13671367

1368-
# TODO: The commented out tests here are the ones where `method_for_inference_heuristics`
1368+
# TODO: The commented out tests here are the ones where `signature_for_inference_heuristics`
13691369
# is inflated; these tests cause segfaults. Probably due to incorrect CodeInfo
13701370
# construction/initialization happening somewhere...
13711371

@@ -1377,5 +1377,5 @@ result = f24852_kernel(x, y)
13771377
@test result === f24852_early_uninflated(x, y)
13781378
# @test result === f24852_early_inflated(x, y)
13791379

1380-
# TODO: test that `expand_early = true` + inflated `method_for_inference_heuristics`
1380+
# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
13811381
# can be used to tighten up some inference result.

0 commit comments

Comments
 (0)