Skip to content

Commit 384a474

Browse files
committed
replace ANY with @nospecialize annotation. part of #11339
1 parent 4b345c1 commit 384a474

15 files changed

+108
-41
lines changed

base/boot.jl

+4
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ macro _noinline_meta()
196196
Expr(:meta, :noinline)
197197
end
198198

199+
macro nospecialize(x)
200+
Expr(:meta, :nospecialize, x)
201+
end
202+
199203
struct BoundsError <: Exception
200204
a::Any
201205
i::Any

base/essentials.jl

+23
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@ end
1717
macro _noinline_meta()
1818
Expr(:meta, :noinline)
1919
end
20+
"""
21+
@nospecialize
22+
23+
Applied to a function argument name, hints to the compiler that the method
24+
should not be specialized for different types of the specified argument.
25+
This is only a hint for avoiding excess code generation.
26+
Can be applied to an argument within a formal argument list, or in the
27+
function body:
28+
29+
```julia
30+
function example_function(@nospecialize x)
31+
...
32+
end
33+
34+
function example_function(x, y, z)
35+
@nospecialize x y
36+
...
37+
end
38+
```
39+
"""
40+
macro nospecialize(var, vars...)
41+
Expr(:meta, :nospecialize, var, vars...)
42+
end
2043
macro _pure_meta()
2144
Expr(:meta, :pure)
2245
end

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,7 @@ export
12851285
@simd,
12861286
@inline,
12871287
@noinline,
1288+
@nospecialize,
12881289
@polly,
12891290

12901291
@assert,

src/ast.c

+2-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jl_sym_t *inert_sym; jl_sym_t *vararg_sym;
5656
jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym;
5757
jl_sym_t *polly_sym; jl_sym_t *inline_sym;
5858
jl_sym_t *propagate_inbounds_sym;
59-
jl_sym_t *isdefined_sym;
59+
jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym;
6060

6161
static uint8_t flisp_system_image[] = {
6262
#include <julia_flisp.boot.inc>
@@ -433,6 +433,7 @@ void jl_init_frontend(void)
433433
inline_sym = jl_symbol("inline");
434434
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
435435
isdefined_sym = jl_symbol("isdefined");
436+
nospecialize_sym = jl_symbol("nospecialize");
436437
}
437438

438439
JL_DLLEXPORT void jl_lisp_prompt(void)

src/ast.scm

+11
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@
152152
(if (not (symbol? (cadr v)))
153153
(bad-formal-argument (cadr v)))
154154
(decl-var v))
155+
((meta) ;; allow certain per-argument annotations
156+
(if (and (length= v 3) (eq? (cadr v) 'nospecialize))
157+
(arg-name (caddr v))
158+
(bad-formal-argument v)))
155159
(else (bad-formal-argument v))))))
156160

157161
(define (arg-type v)
@@ -167,6 +171,10 @@
167171
(if (not (symbol? (cadr v)))
168172
(bad-formal-argument (cadr v)))
169173
(decl-type v))
174+
((meta) ;; allow certain per-argument annotations
175+
(if (and (length= v 3) (eq? (cadr v) 'nospecialize))
176+
(arg-type (caddr v))
177+
(bad-formal-argument v)))
170178
(else (bad-formal-argument v))))))
171179

172180
;; convert a lambda list into a list of just symbols
@@ -310,6 +318,9 @@
310318
(define (kwarg? e)
311319
(and (pair? e) (eq? (car e) 'kw)))
312320

321+
(define (nospecialize-meta? e)
322+
(and (length> e 2) (eq? (car e) 'meta) (eq? (cadr e) 'nospecialize)))
323+
313324
;; flatten nested expressions with the given head
314325
;; (op (op a b) c) => (op a b c)
315326
(define (flatten-ex head e)

src/gf.c

+3-15
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,7 @@ static void jl_cacheable_sig(
611611

612612
int notcalled_func = (i > 0 && i <= 8 && !(definition->called & (1 << (i - 1))) &&
613613
jl_subtype(elt, (jl_value_t*)jl_function_type));
614-
if (decl_i == jl_ANY_flag) {
615-
// don't specialize on slots marked ANY
614+
if (i > 0 && i <= 8 && definition->nospec & (1 << (i - 1))) {
616615
if (!*newparams) *newparams = jl_svec_copy(type->parameters);
617616
jl_svecset(*newparams, i, (jl_value_t*)jl_any_type);
618617
*need_guard_entries = 1;
@@ -714,9 +713,8 @@ JL_DLLEXPORT int jl_is_cacheable_sig(
714713
continue;
715714
if (jl_is_kind(elt)) // kind slots always need guard entries (checking for subtypes of Type)
716715
continue;
717-
if (decl_i == jl_ANY_flag) {
718-
// don't specialize on slots marked ANY
719-
if (elt != (jl_value_t*)jl_any_type && elt != jl_ANY_flag)
716+
if (i > 0 && i <= 8 && definition->nospec & (1 << (i - 1))) {
717+
if (elt != (jl_value_t*)jl_any_type)
720718
return 0;
721719
continue;
722720
}
@@ -2258,16 +2256,6 @@ static int ml_matches_visitor(jl_typemap_entry_t *ml, struct typemap_intersectio
22582256
break;
22592257
}
22602258
}
2261-
// don't analyze slots declared with ANY
2262-
// TODO
2263-
/*
2264-
l = jl_nparams(ml->sig);
2265-
size_t m = jl_nparams(ti);
2266-
for(i=0; i < l && i < m; i++) {
2267-
if (jl_tparam(ml->sig, i) == jl_ANY_flag)
2268-
jl_tupleset(ti, i, jl_any_type);
2269-
}
2270-
*/
22712259
}
22722260
if (!skip) {
22732261
/*

src/jltypes.c

+5-7
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ static int typeenv_has(jl_typeenv_t *env, jl_tvar_t *v)
146146
static int has_free_typevars(jl_value_t *v, jl_typeenv_t *env)
147147
{
148148
if (jl_typeis(v, jl_tvar_type)) {
149-
if (v == jl_ANY_flag) return 0;
150149
return !typeenv_has(env, (jl_tvar_t*)v);
151150
}
152151
if (jl_is_uniontype(v))
@@ -181,7 +180,6 @@ JL_DLLEXPORT int jl_has_free_typevars(jl_value_t *v)
181180
static void find_free_typevars(jl_value_t *v, jl_typeenv_t *env, jl_array_t *out)
182181
{
183182
if (jl_typeis(v, jl_tvar_type)) {
184-
if (v == jl_ANY_flag) return;
185183
if (!typeenv_has(env, (jl_tvar_t*)v))
186184
jl_array_ptr_1d_push(out, v);
187185
}
@@ -238,7 +236,7 @@ static int jl_has_bound_typevars(jl_value_t *v, jl_typeenv_t *env)
238236
return ans;
239237
}
240238
if (jl_is_datatype(v)) {
241-
if (!((jl_datatype_t*)v)->hasfreetypevars && !(env && env->var == (jl_tvar_t*)jl_ANY_flag))
239+
if (!((jl_datatype_t*)v)->hasfreetypevars)
242240
return 0;
243241
size_t i;
244242
for (i=0; i < jl_nparams(v); i++) {
@@ -669,8 +667,6 @@ static int is_cacheable(jl_datatype_t *type)
669667
assert(jl_is_datatype(type));
670668
jl_svec_t *t = type->parameters;
671669
if (jl_svec_len(t) == 0) return 0;
672-
if (jl_has_typevar((jl_value_t*)type, (jl_tvar_t*)jl_ANY_flag))
673-
return 0;
674670
// cache abstract types with no free type vars
675671
if (jl_is_abstracttype(type))
676672
return !jl_has_free_typevars((jl_value_t*)type);
@@ -1939,7 +1935,7 @@ void jl_init_types(void)
19391935
jl_method_type =
19401936
jl_new_datatype(jl_symbol("Method"), core,
19411937
jl_any_type, jl_emptysvec,
1942-
jl_perm_symsvec(19,
1938+
jl_perm_symsvec(20,
19431939
"name",
19441940
"module",
19451941
"file",
@@ -1956,10 +1952,11 @@ void jl_init_types(void)
19561952
"invokes",
19571953
"nargs",
19581954
"called",
1955+
"nospec",
19591956
"isva",
19601957
"isstaged",
19611958
"pure"),
1962-
jl_svec(19,
1959+
jl_svec(20,
19631960
jl_sym_type,
19641961
jl_module_type,
19651962
jl_sym_type,
@@ -1976,6 +1973,7 @@ void jl_init_types(void)
19761973
jl_any_type,
19771974
jl_int32_type,
19781975
jl_int32_type,
1976+
jl_int32_type,
19791977
jl_bool_type,
19801978
jl_bool_type,
19811979
jl_bool_type),

src/julia-syntax.scm

+16-2
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,17 @@
10261026
(string "function Base.broadcast(::typeof(" (deparse op_) "), ...)")))
10271027
op_))
10281028
(name (if op '(|.| Base (inert broadcast)) name))
1029+
(annotations (map (lambda (a)
1030+
`(meta nospecialize ,(arg-name a)))
1031+
(filter nospecialize-meta? argl)))
1032+
(body (if (null? annotations)
1033+
(caddr e)
1034+
(insert-after-meta (caddr e) annotations)))
1035+
(argl (map (lambda (a)
1036+
(if (nospecialize-meta? a)
1037+
(caddr a)
1038+
a))
1039+
argl))
10291040
(argl (if op (cons `(|::| (call (core Typeof) ,op)) argl) argl))
10301041
(sparams (map analyze-typevar (cond (has-sp (cddr head))
10311042
(where where)
@@ -1046,7 +1057,7 @@
10461057
(name (if (or (decl? name) (and (pair? name) (eq? (car name) 'curly)))
10471058
#f name)))
10481059
(expand-forms
1049-
(method-def-expr name sparams argl (caddr e) isstaged rett))))
1060+
(method-def-expr name sparams argl body isstaged rett))))
10501061
(else
10511062
(error (string "invalid assignment location \"" (deparse name) "\""))))))
10521063

@@ -1170,7 +1181,7 @@
11701181
(|::| __module__ (core Module))
11711182
,@(map (lambda (v)
11721183
(if (symbol? v)
1173-
`(|::| ,v (core ANY))
1184+
`(|::| ,v (core ANY)) ;; TODO: ANY deprecation
11741185
v))
11751186
anames))
11761187
,@(cddr e)))))
@@ -3780,6 +3791,9 @@ f(x) = yt(x)
37803791
((and (pair? e) (eq? (car e) 'outerref))
37813792
(let ((idx (get sp-table (cadr e) #f)))
37823793
(if idx `(static_parameter ,idx) (cadr e))))
3794+
((and (length> e 2) (eq? (car e) 'meta) (eq? (cadr e) 'nospecialize))
3795+
;; convert nospecialize vars to slot numbers
3796+
`(meta nospecialize ,@(map renumber-slots (cddr e))))
37833797
((or (atom? e) (quoted? e)) e)
37843798
((ssavalue? e)
37853799
(let ((idx (or (get ssavalue-table (cadr e) #f)

src/julia.h

+1
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ typedef struct _jl_method_t {
255255

256256
int32_t nargs;
257257
int32_t called; // bit flags: whether each of the first 8 arguments is called
258+
int32_t nospec; // bit flags: which arguments should not be specialized
258259
uint8_t isva;
259260
uint8_t isstaged;
260261
uint8_t pure;

src/julia_internal.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ extern jl_sym_t *meta_sym; extern jl_sym_t *list_sym;
10011001
extern jl_sym_t *inert_sym; extern jl_sym_t *static_parameter_sym;
10021002
extern jl_sym_t *polly_sym; extern jl_sym_t *inline_sym;
10031003
extern jl_sym_t *propagate_inbounds_sym;
1004-
extern jl_sym_t *isdefined_sym;
1004+
extern jl_sym_t *isdefined_sym; extern jl_sym_t *nospecialize_sym;
10051005

10061006
void jl_register_fptrs(uint64_t sysimage_base, const char *base, const int32_t *offsets,
10071007
jl_method_instance_t **linfos, size_t n);

src/macroexpand.scm

+4
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@
192192
(case (car v)
193193
((... kw |::|) (try-arg-name (cadr v)))
194194
((escape) (list v))
195+
((meta) ;; allow certain per-argument annotations
196+
(if (and (length= v 3) (eq? (cadr v) 'nospecialize))
197+
(try-arg-name (caddr v))
198+
'()))
195199
(else '())))))
196200

197201
;; get names from a formal argument list, specifying whether to include escaped ones

src/method.c

+21-1
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,15 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
434434
set_lineno = 1;
435435
}
436436
}
437+
else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == meta_sym &&
438+
jl_expr_nargs(st) > 1 && jl_exprarg(st,0) == (jl_value_t*)nospecialize_sym) {
439+
for(size_t j=1; j < jl_expr_nargs(st); j++) {
440+
jl_value_t *aj = jl_exprarg(st, j);
441+
if (jl_is_slot(aj))
442+
m->nospec |= (1 << (jl_slot_number(aj) - 2));
443+
}
444+
st = jl_nothing;
445+
}
437446
else {
438447
st = jl_resolve_globals(st, m->module, sparam_vars);
439448
}
@@ -465,6 +474,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
465474
m->file = empty_sym;
466475
m->line = 0;
467476
m->called = 0xff;
477+
m->nospec = 0;
468478
m->invokes.unknown = NULL;
469479
m->isstaged = 0;
470480
m->isva = 0;
@@ -642,6 +652,16 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata,
642652
jl_methtable_t *mt;
643653
jl_sym_t *name;
644654
jl_method_t *m = NULL;
655+
size_t i, na = jl_svec_len(atypes);
656+
int32_t nospec = 0;
657+
for(i=1; i < na; i++) {
658+
jl_value_t *ti = jl_svecref(atypes, i);
659+
if (ti == jl_ANY_flag ||
660+
(jl_is_vararg_type(ti) && jl_tparam0(jl_unwrap_unionall(ti)) == jl_ANY_flag)) {
661+
nospec |= (1 << (i - 1));
662+
jl_svecset(atypes, i, jl_substitute_var(ti, (jl_tvar_t*)jl_ANY_flag, (jl_value_t*)jl_any_type));
663+
}
664+
}
645665
jl_value_t *argtype = (jl_value_t*)jl_apply_tuple_type(atypes);
646666
JL_GC_PUSH3(&f, &m, &argtype);
647667

@@ -675,6 +695,7 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata,
675695
}
676696

677697
m = jl_new_method(f, name, module, (jl_tupletype_t*)argtype, nargs, isva, tvars, isstaged == jl_true);
698+
m->nospec |= nospec;
678699

679700
if (jl_has_free_typevars(argtype)) {
680701
jl_exceptionf(jl_argumenterror_type,
@@ -686,7 +707,6 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata,
686707

687708
jl_check_static_parameter_conflicts(m, f, tvars);
688709

689-
size_t i, na = jl_svec_len(atypes);
690710
for (i = 0; i < na; i++) {
691711
jl_value_t *elt = jl_svecref(atypes, i);
692712
if (!jl_is_type(elt) && !jl_is_typevar(elt)) {

src/subtype.c

+2-10
Original file line numberDiff line numberDiff line change
@@ -790,8 +790,6 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);
790790
// diagonal rule (record_var_occurrence).
791791
static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
792792
{
793-
if (x == jl_ANY_flag) x = (jl_value_t*)jl_any_type;
794-
if (y == jl_ANY_flag) y = (jl_value_t*)jl_any_type;
795793
if (jl_is_uniontype(x)) {
796794
if (x == y) return 1;
797795
x = pick_union_element(x, e, 0);
@@ -1826,8 +1824,6 @@ static jl_value_t *intersect_type_type(jl_value_t *x, jl_value_t *y, jl_stenv_t
18261824
static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
18271825
{
18281826
if (x == y) return y;
1829-
if (x == jl_ANY_flag) x = (jl_value_t*)jl_any_type;
1830-
if (y == jl_ANY_flag) y = (jl_value_t*)jl_any_type;
18311827
if (jl_is_typevar(x)) {
18321828
if (jl_is_typevar(y)) {
18331829
jl_varbinding_t *xx = lookup(e, (jl_tvar_t*)x);
@@ -2219,10 +2215,6 @@ JL_DLLEXPORT jl_svec_t *jl_env_from_type_intersection(jl_value_t *a, jl_value_t
22192215

22202216
static int eq_msp(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env)
22212217
{
2222-
// equate ANY and Any for specificity purposes, #16153
2223-
if ((a == (jl_value_t*)jl_any_type && b == jl_ANY_flag) ||
2224-
(b == (jl_value_t*)jl_any_type && a == jl_ANY_flag))
2225-
return 1;
22262218
if (!(jl_is_type(a) || jl_is_typevar(a)) ||
22272219
!(jl_is_type(b) || jl_is_typevar(b)))
22282220
return jl_egal(a, b);
@@ -2508,8 +2500,8 @@ static int type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant, jl_ty
25082500
}
25092501

25102502
if (!invariant) {
2511-
if ((jl_datatype_t*)a == jl_any_type || a == jl_ANY_flag) return 0;
2512-
if ((jl_datatype_t*)b == jl_any_type || b == jl_ANY_flag) return 1;
2503+
if ((jl_datatype_t*)a == jl_any_type) return 0;
2504+
if ((jl_datatype_t*)b == jl_any_type) return 1;
25132505
}
25142506

25152507
if (jl_is_datatype(a) && jl_is_datatype(b)) {

src/typemap.c

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ extern "C" {
1919
// compute whether the specificity of this type is equivalent to Any in the sort order
2020
static int jl_is_any(jl_value_t *t1)
2121
{
22-
return (t1 == (jl_value_t*)jl_any_type || t1 == jl_ANY_flag ||
22+
return (t1 == (jl_value_t*)jl_any_type ||
2323
(jl_is_typevar(t1) &&
2424
((jl_tvar_t*)t1)->ub == (jl_value_t*)jl_any_type));
2525
}
@@ -69,7 +69,7 @@ static int sig_match_by_type_simple(jl_value_t **types, size_t n, jl_tupletype_t
6969
return 0;
7070
}
7171
}
72-
else if (decl == (jl_value_t*)jl_any_type || decl == jl_ANY_flag) {
72+
else if (decl == (jl_value_t*)jl_any_type) {
7373
}
7474
else {
7575
if (jl_is_type_type(a)) // decl is not Type, because it would be caught above
@@ -122,8 +122,7 @@ static inline int sig_match_simple(jl_value_t **args, size_t n, jl_value_t **sig
122122
for (i = 0; i < lensig; i++) {
123123
jl_value_t *decl = sig[i];
124124
jl_value_t *a = args[i];
125-
if (decl == (jl_value_t*)jl_any_type || decl == jl_ANY_flag ||
126-
((jl_value_t*)jl_typeof(a) == decl)) {
125+
if (decl == (jl_value_t*)jl_any_type || ((jl_value_t*)jl_typeof(a) == decl)) {
127126
/*
128127
we are only matching concrete types here, and those types are
129128
hash-consed, so pointer comparison should work.

0 commit comments

Comments
 (0)