Skip to content

Commit bbb9284

Browse files
committed
WIP: allow @generated begin ... end inside a function to provide an optional optimizer
1 parent 00f0d23 commit bbb9284

13 files changed

+111
-60
lines changed

base/expr.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,17 @@ end
334334

335335
macro generated(f)
336336
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
337-
pushmeta!(f, :generated)
338-
return Expr(:escape, f)
337+
body = f.args[2]
338+
lno = body.args[1]
339+
return Expr(:escape,
340+
Expr(f.head, f.args[1],
341+
Expr(:block,
342+
lno,
343+
Expr(:meta, :generator, body),
344+
Expr(:meta, :generated_only),
345+
Expr(:return, nothing))))
346+
elseif isa(f, Expr) && f.head === :block
347+
return Expr(:escape, Expr(:meta, :generator, f))
339348
else
340349
error("invalid syntax; @generated must be used with a function definition")
341350
end

base/methodshow.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar
4242
return s, string_with_env(env, t)
4343
end
4444

45+
function method_argnames(m::Method)
46+
if !isdefined(m, :source)
47+
gm = first(methods(m.generator))
48+
return method_argnames(gm)[length(m.sparam_syms)+2 : end]
49+
end
50+
argnames = Vector{Any}(m.nargs)
51+
ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames)
52+
return argnames
53+
end
54+
4555
function arg_decl_parts(m::Method)
4656
tv = Any[]
4757
sig = m.sig
@@ -52,8 +62,7 @@ function arg_decl_parts(m::Method)
5262
file = m.file
5363
line = m.line
5464
if isdefined(m, :source) || isdefined(m, :generator)
55-
argnames = Vector{Any}(m.nargs)
56-
ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames)
65+
argnames = method_argnames(m)
5766
show_env = ImmutableDict{Symbol, Any}()
5867
for t in tv
5968
show_env = ImmutableDict(show_env, :unionall_env => t)

base/reflection.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,8 @@ function length(mt::MethodTable)
682682
end
683683
isempty(mt::MethodTable) = (mt.defs === nothing)
684684

685-
uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m,:source) ? m.source : m.generator.inferred)
685+
uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) :
686+
uncompressed_ast(first(methods(m.generator)))
686687
uncompressed_ast(m::Method, s::CodeInfo) = s
687688
uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo
688689

src/ast.c

+4-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ jl_sym_t *meta_sym; jl_sym_t *compiler_temp_sym;
5555
jl_sym_t *inert_sym; jl_sym_t *vararg_sym;
5656
jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym;
5757
jl_sym_t *polly_sym; jl_sym_t *inline_sym;
58-
jl_sym_t *propagate_inbounds_sym; jl_sym_t *generated_sym;
58+
jl_sym_t *propagate_inbounds_sym; jl_sym_t *generator_sym;
59+
jl_sym_t *generated_only_sym;
5960
jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym;
6061

6162
static uint8_t flisp_system_image[] = {
@@ -437,7 +438,8 @@ void jl_init_frontend(void)
437438
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
438439
isdefined_sym = jl_symbol("isdefined");
439440
nospecialize_sym = jl_symbol("nospecialize");
440-
generated_sym = jl_symbol("generated");
441+
generator_sym = jl_symbol("generator");
442+
generated_only_sym = jl_symbol("generated_only");
441443
}
442444

443445
JL_DLLEXPORT void jl_lisp_prompt(void)

src/ast.scm

+5-2
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,11 @@
345345
(and (if one (length= e 3) (length> e 2))
346346
(eq? (car e) 'meta) (eq? (cadr e) 'nospecialize)))
347347

348-
(define (generated-meta? e)
349-
(and (pair? e) (eq? (car e) 'meta) (any (lambda (x) (eq? x 'generated)) (cdr e))))
348+
(define (generator-meta? e)
349+
(and (length= e 3) (eq? (car e) 'meta) (eq? (cadr e) 'generator)))
350+
351+
(define (generated_only-meta? e)
352+
(and (length= e 2) (eq? (car e) 'meta) (eq? (cadr e) 'generated_only)))
350353

351354
;; flatten nested expressions with the given head
352355
;; (op (op a b) c) => (op a b c)

src/codegen.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,6 @@ jl_llvm_functions_t jl_compile_linfo(jl_method_instance_t **pli, jl_code_info_t
11971197
li->inferred &&
11981198
// and there is something to delete (test this before calling jl_ast_flag_inlineable)
11991199
li->inferred != jl_nothing &&
1200-
// don't delete the code for the generator
1201-
li != li->def.method->generator &&
12021200
// don't delete inlineable code, unless it is constant
12031201
(li->jlcall_api == 2 || !jl_ast_flag_inlineable((jl_array_t*)li->inferred)) &&
12041202
// don't delete code when generating a precompile file

src/dump.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -1417,7 +1417,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
14171417
m->unspecialized = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->unspecialized);
14181418
if (m->unspecialized)
14191419
jl_gc_wb(m, m->unspecialized);
1420-
m->generator = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->generator);
1420+
m->generator = jl_deserialize_value(s, (jl_value_t**)&m->generator);
14211421
if (m->generator)
14221422
jl_gc_wb(m, m->generator);
14231423
m->invokes.unknown = jl_deserialize_value(s, (jl_value_t**)&m->invokes);

src/jltypes.c

+1-2
Original file line numberDiff line numberDiff line change
@@ -2045,7 +2045,7 @@ void jl_init_types(void)
20452045
jl_simplevector_type,
20462046
jl_any_type,
20472047
jl_any_type, // jl_method_instance_type
2048-
jl_any_type, // jl_method_instance_type
2048+
jl_any_type,
20492049
jl_array_any_type,
20502050
jl_any_type,
20512051
jl_int32_type,
@@ -2158,7 +2158,6 @@ void jl_init_types(void)
21582158
#endif
21592159
jl_svecset(jl_methtable_type->types, 8, jl_int32_type); // uint32_t
21602160
jl_svecset(jl_method_type->types, 10, jl_method_instance_type);
2161-
jl_svecset(jl_method_type->types, 11, jl_method_instance_type);
21622161
jl_svecset(jl_method_instance_type->types, 12, jl_voidpointer_type);
21632162
jl_svecset(jl_method_instance_type->types, 13, jl_voidpointer_type);
21642163
jl_svecset(jl_method_instance_type->types, 14, jl_voidpointer_type);

src/julia-syntax.scm

+28-12
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,17 @@
294294
(map (lambda (x) (replace-outer-vars x renames))
295295
(cdr e))))))
296296

297+
(define (make-generator-function name sp-names arg-names body)
298+
(let ((arg-names (append sp-names
299+
(map (lambda (n)
300+
(if (eq? n '|#self#|) (gensy) n))
301+
arg-names))))
302+
(let ((body (insert-after-meta body ;; don't specialize on generator arguments
303+
`((meta nospecialize ,@arg-names)))))
304+
`(block
305+
(global ,name)
306+
(function (call ,name ,@arg-names) ,body)))))
307+
297308
;; construct the (method ...) expression for one primitive method definition,
298309
;; assuming optional and keyword args are already handled
299310
(define (method-def-expr- name sparams argl body (rett '(core Any)))
@@ -328,7 +339,14 @@
328339
(error "function argument and static parameter names must be distinct")))
329340
(if (or (and name (not (sym-ref? name))) (eq? name 'true) (eq? name 'false))
330341
(error (string "invalid function name \"" (deparse name) "\"")))
331-
(let* ((types (llist-types argl))
342+
(let* ((generator (let ((found (find generator-meta? body)))
343+
(if found
344+
(let* ((gname (symbol (string (gensy) "#" (current-julia-module-counter))))
345+
(gf (make-generator-function gname names (llist-vars argl) (caddr (car found)))))
346+
(set-car! (cddar found) gname)
347+
(list gf))
348+
'())))
349+
(types (llist-types argl))
332350
(body (method-lambda-expr argl body rett))
333351
;; HACK: the typevars need to be bound to ssavalues, since this code
334352
;; might be moved to a different scope by closure-convert.
@@ -360,8 +378,10 @@
360378
(call (core svec) ,@temps)))
361379
,body))))
362380
(if (symbol? name)
363-
`(block (method ,name) ,mdef (unnecessary ,name)) ;; return the function
364-
mdef)))))
381+
`(block ,@generator (method ,name) ,mdef (unnecessary ,name)) ;; return the function
382+
(if (not (null? generator))
383+
`(block ,@generator ,mdef)
384+
mdef))))))
365385

366386
;; wrap expr in nested scopes assigning names to vals
367387
(define (scopenest names vals expr)
@@ -411,11 +431,8 @@
411431
keynames))
412432
;; list of function's initial line number and meta nodes (empty if none)
413433
(prologue (extract-method-prologue body))
414-
(annotations (append (if (any generated-meta? prologue)
415-
'((meta generated))
416-
'())
417-
(map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a)))))
418-
(filter nospecialize-meta? kargl))))
434+
(annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a)))))
435+
(filter nospecialize-meta? kargl)))
419436
;; body statements
420437
(stmts (cdr body))
421438
(positional-sparams
@@ -565,10 +582,9 @@
565582
'()))
566583

567584
(define (without-generated stmts)
568-
(map (lambda (x) (if (generated-meta? x)
569-
(filter (lambda (e) (not (eq? e 'generated))) x)
570-
x))
571-
stmts))
585+
(filter (lambda (x) (not (or (generator-meta? x)
586+
(generated_only-meta? x))))
587+
stmts))
572588

573589
;; keep only sparams used by `expr` or other sparams
574590
(define (filter-sparams expr sparams)

src/julia.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ typedef struct _jl_method_t {
247247
jl_svec_t *sparam_syms; // symbols giving static parameter names
248248
jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins
249249
struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null
250-
struct _jl_method_instance_t *generator; // executable code-generating function if available
250+
jl_value_t *generator; // executable code-generating function if available
251251
jl_array_t *roots; // pointers in generated code (shared to reduce memory), or null
252252

253253
// cache of specializations of this method for invoke(), i.e.

src/julia_internal.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,8 @@ extern jl_sym_t *pure_sym; extern jl_sym_t *simdloop_sym;
994994
extern jl_sym_t *meta_sym; extern jl_sym_t *list_sym;
995995
extern jl_sym_t *inert_sym; extern jl_sym_t *static_parameter_sym;
996996
extern jl_sym_t *polly_sym; extern jl_sym_t *inline_sym;
997-
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *generated_sym;
997+
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *generator_sym;
998+
extern jl_sym_t *generated_only_sym;
998999
extern jl_sym_t *isdefined_sym; extern jl_sym_t *nospecialize_sym;
9991000

10001001
void jl_register_fptrs(uint64_t sysimage_base, const char *base, const int32_t *offsets,

src/method.c

+40-32
Original file line numberDiff line numberDiff line change
@@ -245,24 +245,23 @@ jl_code_info_t *jl_new_code_info_from_ast(jl_expr_t *ast)
245245
}
246246

247247
// invoke (compiling if necessary) the jlcall function pointer for a method template
248-
STATIC_INLINE jl_value_t *jl_call_staged(jl_svec_t *sparam_vals, jl_method_instance_t *generator,
248+
STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, jl_svec_t *sparam_vals,
249249
jl_value_t **args, uint32_t nargs)
250250
{
251-
jl_generic_fptr_t fptr;
252-
fptr.fptr = generator->fptr;
253-
fptr.jlcall_api = generator->jlcall_api;
254-
if (__unlikely(fptr.fptr == NULL || fptr.jlcall_api == 0)) {
255-
size_t world = generator->def.method->min_world;
256-
const char *F = jl_compile_linfo(&generator, (jl_code_info_t*)generator->inferred, world, &jl_default_cgparams).functionObject;
257-
fptr = jl_generate_fptr(generator, F, world);
251+
size_t spl = jl_svec_len(sparam_vals);
252+
jl_value_t **gargs;
253+
size_t totargs = 1 + spl + nargs + def->isva;
254+
JL_GC_PUSHARGS(gargs, totargs);
255+
gargs[0] = generator;
256+
memcpy(&gargs[1], jl_svec_data(sparam_vals), spl * sizeof(void*));
257+
memcpy(&gargs[1+spl], args, nargs * sizeof(void*));
258+
if (def->isva) {
259+
gargs[totargs-1] = jl_f_tuple(NULL, &gargs[1+spl+def->nargs-1], nargs - (def->nargs-1));
260+
gargs[1+spl+def->nargs-1] = gargs[totargs-1];
258261
}
259-
assert(jl_svec_len(generator->def.method->sparam_syms) == jl_svec_len(sparam_vals));
260-
if (fptr.jlcall_api == 1)
261-
return fptr.fptr1(args[0], &args[1], nargs-1);
262-
else if (fptr.jlcall_api == 3)
263-
return fptr.fptr3(sparam_vals, args[0], &args[1], nargs-1);
264-
else
265-
abort(); // shouldn't have inferred any other calling convention
262+
jl_value_t *code = jl_apply(gargs, 1+spl+def->nargs);
263+
JL_GC_POP();
264+
return code;
266265
}
267266

268267
// return a newly allocated CodeInfo for the function signature
@@ -275,9 +274,11 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
275274
jl_expr_t *ex = NULL;
276275
jl_value_t *linenum = NULL;
277276
jl_svec_t *sparam_vals = env;
278-
jl_method_instance_t *generator = linfo->def.method->generator;
277+
jl_value_t *generator = linfo->def.method->generator;
278+
jl_method_t *gen_meth = jl_gf_mtable(generator)->defs.leaf->func.method;
279279
assert(generator != NULL);
280280
assert(linfo != generator);
281+
assert(jl_is_method(gen_meth));
281282
jl_code_info_t *func = NULL;
282283
JL_GC_PUSH4(&ex, &linenum, &sparam_vals, &func);
283284
jl_ptls_t ptls = jl_get_ptls_states();
@@ -292,13 +293,14 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
292293
// need to eval macros in the right module
293294
ptls->current_task->current_module = ptls->current_module = linfo->def.method->module;
294295
// and the right world
295-
ptls->world_age = generator->def.method->min_world;
296+
ptls->world_age = gen_meth->min_world;
296297

297298
ex = jl_exprn(lambda_sym, 2);
298299

299-
jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs);
300+
jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs + jl_svec_len(sparam_vals) + 1);
300301
jl_array_ptr_set(ex->args, 0, argnames);
301-
jl_fill_argnames((jl_array_t*)generator->inferred, argnames);
302+
jl_fill_argnames((jl_array_t*)gen_meth->source, argnames);
303+
jl_array_del_beg(argnames, jl_svec_len(sparam_vals) + 1);
302304

303305
// build the rest of the body to pass to expand
304306
jl_expr_t *scopeblock = jl_exprn(jl_symbol("scope-block"), 1);
@@ -319,7 +321,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
319321
// invoke code generator
320322
assert(jl_nparams(tt) == jl_array_len(argnames) ||
321323
(linfo->def.method->isva && (jl_nparams(tt) >= jl_array_len(argnames) - 1)));
322-
jl_value_t *generated_body = jl_call_staged(sparam_vals, generator, jl_svec_data(tt->parameters), jl_nparams(tt));
324+
jl_value_t *generated_body = jl_call_staged(linfo->def.method, generator, sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt));
323325
jl_array_ptr_set(body->args, 2, generated_body);
324326

325327
if (jl_is_code_info(generated_body)) {
@@ -398,7 +400,7 @@ jl_method_instance_t *jl_get_specialized(jl_method_t *m, jl_value_t *types, jl_s
398400
return new_linfo;
399401
}
400402

401-
static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, int *isstaged)
403+
static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, jl_value_t **generator, int *gen_only)
402404
{
403405
uint8_t j;
404406
uint8_t called = 0;
@@ -470,12 +472,17 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, int *issta
470472
}
471473
st = jl_nothing;
472474
}
473-
else {
474-
for (size_t j=0; j < jl_expr_nargs(st); j++) {
475-
if (jl_exprarg(st, j) == (jl_value_t*)generated_sym) {
476-
*isstaged = 1; break;
477-
}
475+
else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generator_sym) {
476+
jl_value_t *gname = jl_exprarg(st, 1);
477+
*generator = jl_get_global(m->module, (jl_sym_t*)gname);
478+
if (*generator == NULL) {
479+
jl_error("invalid @generated function; try placing it in global scope");
478480
}
481+
st = jl_nothing;
482+
}
483+
else if (jl_expr_nargs(st) == 1 && jl_exprarg(st, 0) == (jl_value_t*)generated_only_sym) {
484+
*gen_only = 1;
485+
st = jl_nothing;
479486
}
480487
}
481488
else {
@@ -530,7 +537,6 @@ static jl_method_t *jl_new_method(
530537
jl_svec_t *tvars)
531538
{
532539
size_t i, l = jl_svec_len(tvars);
533-
int isstaged = 0;
534540
jl_svec_t *sparam_syms = jl_alloc_svec_uninit(l);
535541
for (i = 0; i < l; i++) {
536542
jl_svecset(sparam_syms, i, ((jl_tvar_t*)jl_svecref(tvars, i))->name);
@@ -547,12 +553,14 @@ static jl_method_t *jl_new_method(
547553
m->sig = (jl_value_t*)sig;
548554
m->isva = isva;
549555
m->nargs = nargs;
550-
jl_method_set_source(m, definition, &isstaged);
551-
if (isstaged) {
552-
// create and store generator for generated functions
553-
m->generator = jl_get_specialized(m, (jl_value_t*)jl_anytuple_type, jl_emptysvec);
556+
jl_value_t *gen = NULL; int gen_only = 0;
557+
jl_method_set_source(m, definition, &gen, &gen_only);
558+
if (gen) {
559+
m->generator = gen;
554560
jl_gc_wb(m, m->generator);
555-
m->generator->inferred = (jl_value_t*)m->source;
561+
}
562+
if (gen_only) {
563+
assert(gen);
556564
m->source = NULL;
557565
}
558566

src/utils.scm

+5
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,8 @@
8484
(without (cdr alst) remove)))))
8585

8686
(define (caddddr x) (car (cdr (cdr (cdr (cdr x))))))
87+
88+
(define (find p lst)
89+
(cond ((atom? lst) #f)
90+
((p (car lst)) lst)
91+
(else (find p (cdr lst)))))

0 commit comments

Comments
 (0)