Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: controlling dispatch with varargs of defined length #10691

Closed
wants to merge 12 commits into from
Closed
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ New language features
and macros in packages and user code ([#8791]). Type `?@doc` at the repl
to see the current syntax and more information.

* Varargs functions may now declare the varargs length as `x...N` to
restrict dispatch.

Language changes
----------------

Expand Down
51 changes: 48 additions & 3 deletions base/inference.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
### Infer the types of variables in functions, given the types of the arguments

# A few general notes:
# - The entry point to the inference code is typeinf_ext, which gets called from
# C and performs type inference on the AST and supplied argument types.
# - This works by "simulating" your code at the level of expressions.
# The abstract_eval_* series of functions computes the types of the objects that would
# be produced from eval-ing the corresponding individual expressions.
# - Inlining is performed during type inference; the main entry point is inlining_pass
# called from typeinf_uncached.
# - Julia's type system can be thought of in terms of sets. Any corresponds to the
# entire domain (any possible type), and Bottom corresponds to the empty set.
# Bottom is used to indicate that two types don't match, i.e., their intersection
# is the empty set; it is also the result type of calls that never return normally (e.g., throw).
# Any is used when types can't be inferred, because potentially any type in the domain
# of all types might be returned.
# - During julia's bootstrap, type inference is initially off, but it gets turned on
# by the ccall at the end of this file.

# parameters limiting potentially-infinite types
const MAX_TYPEUNION_LEN = 3
const MAX_TYPE_DEPTH = 4
Expand Down Expand Up @@ -103,6 +122,10 @@ isType(t::ANY) = isa(t,DataType) && is((t::DataType).name,Type.name)

# Return true when t is a DataType whose type name is identical to Vararg's.
isvarargtype(t::ANY) = isa(t,DataType)&&is((t::DataType).name,Vararg.name)

## t_func is a cache of type information for Julia's builtins. Each entry
# is a 3-tuple (min_nargs, max_nargs, typefun), where typefun is a function
# that computes the return type from the input types. When named, these functions
# often have the pattern fname_tfunc, where fname is the corresponding function.
const t_func = ObjectIdDict()
#t_func[tuple] = (0, Inf, (args...)->limit_tuple_depth(args))
t_func[throw] = (1, 1, x->Bottom)
Expand Down Expand Up @@ -188,7 +211,8 @@ const typeof_tfunc = function (t)
Type{typeof(t)}
end
elseif isvarargtype(t)
Vararg{typeof_tfunc(t.parameters[1])}
length(t.parameters) == 1 ? Vararg{typeof_tfunc(t.parameters[1])} :
Vararg{typeof_tfunc(t.parameters[1]), t.parameters[2]}
elseif isa(t,DataType)
if isleaftype(t)
Type{t}
Expand All @@ -206,8 +230,10 @@ const typeof_tfunc = function (t)
end
end
t_func[typeof] = (1, 1, typeof_tfunc)
# involving constants: typeassert, tupleref, getfield, fieldtype, apply_type
# therefore they get their arguments unevaluated

# The following involve constants: typeassert, tupleref, getfield, fieldtype, apply_type
# For example, getfield(obj, i) can be inferred only if we know the value (not just type)
# of i. Therefore, these inference functions also receive their argument values (the variable A).
t_func[typeassert] =
(2, 2, (A, v, t)->(isType(t) ? typeintersect(v,t.parameters[1]) :
isa(t,Tuple) && all(isType,t) ?
Expand Down Expand Up @@ -523,6 +549,10 @@ function tuple_tfunc(argtypes::ANY, limit)
return t
end

# Perform type inference on a builtin function f for argument types argtypes
# For those builtins that cannot be inferred without knowing the values
# of the arguments (e.g., getfield(obj, i)), also pass the argument values
# in args.
function builtin_tfunction(f::ANY, args::ANY, argtypes::ANY)
isva = isvatuple(argtypes)
if is(f,tuple)
Expand Down Expand Up @@ -660,6 +690,8 @@ const limit_tuple_type_n = function (t::Tuple, lim::Int)
return t
end

# Return the instantiation of a method m, given the argument types tt. env contains
# method parameters, if any.
let stagedcache=Dict{Any,Any}()
global func_for_method
function func_for_method(m::Method, tt, env)
Expand All @@ -681,6 +713,8 @@ let stagedcache=Dict{Any,Any}()
end
end

# f is the function, fargs holds the argument values, argtypes is a tuple of argument types,
# and e is the expression that defines the function call
function abstract_call_gf(f, fargs, argtypes, e)
if length(argtypes)>1 && (argtypes[1] <: Tuple) && argtypes[2]===Int
# allow tuple indexing functions to take advantage of constant
Expand Down Expand Up @@ -802,6 +836,7 @@ function abstract_call_gf(f, fargs, argtypes, e)
return rettype
end

# f is the function, types is a tuple of types in the signature, and argtypes is a tuple of argument types for the call
function invoke_tfunc(f, types, argtypes)
argtypes = typeintersect(types,limit_tuple_type(argtypes))
if is(argtypes,Bottom)
Expand Down Expand Up @@ -879,6 +914,10 @@ function abstract_apply(af, aargtypes, vtypes, sv, e)
return abstract_call(af, (), Tuple, vtypes, sv, ())
end

# Main entry point for inference on a function-call
# f is the function, fargs holds the argument values, argtypes holds the argument types,
# vtypes is an ObjectIdDict of variables and their types, sv (a StaticVarInfo) holds information similar to vtypes,
# and e is the expression that defines the function call.
function abstract_call(f, fargs, argtypes, vtypes, sv::StaticVarInfo, e)
if is(f,_apply) && length(fargs)>1
a2type = argtypes[2]
Expand Down Expand Up @@ -959,6 +998,7 @@ function abstract_call(f, fargs, argtypes, vtypes, sv::StaticVarInfo, e)
return rt
end

# Abstract evaluation of an argument in an expression
function abstract_eval_arg(a::ANY, vtypes::ANY, sv::StaticVarInfo)
t = abstract_eval(a, vtypes, sv)
if isa(t,TypeVar) && t.lb == Bottom && isleaftype(t.ub)
Expand All @@ -967,6 +1007,7 @@ function abstract_eval_arg(a::ANY, vtypes::ANY, sv::StaticVarInfo)
return t
end

# Inference on an Expr(:call, args...)
function abstract_eval_call(e, vtypes, sv::StaticVarInfo)
fargs = e.args[2:end]
argtypes = tuple([abstract_eval_arg(a, vtypes, sv) for a in fargs]...)
Expand Down Expand Up @@ -995,6 +1036,8 @@ function abstract_eval_call(e, vtypes, sv::StaticVarInfo)
return abstract_call(f, fargs, argtypes, vtypes, sv, e)
end

# Main entry point for abstract evaluation of an expression
# e is the expression; vtypes is an ObjectIdDict of variables and their types
function abstract_eval(e::ANY, vtypes, sv::StaticVarInfo)
if isa(e,QuoteNode)
return typeof((e::QuoteNode).value)
Expand Down Expand Up @@ -1340,6 +1383,8 @@ f_argnames(ast) =

# Return true when the C routine jl_is_rest_arg yields a nonzero result for arg.
is_rest_arg(arg::ANY) = (ccall(:jl_is_rest_arg,Int32,(Any,), arg) != 0)

# linfo is the "lambda info", atypes is a tuple containing the argument types,
# sparams is always empty, and def is described below
function typeinf_ext(linfo, atypes::ANY, sparams::ANY, def)
global inference_stack
last = inference_stack
Expand Down
4 changes: 4 additions & 0 deletions doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ You can also return multiple values via an explicit usage of the

This has the exact same effect as the previous definition of ``foo``.

.. _man-vararg-functions:

Varargs Functions
-----------------

Expand Down Expand Up @@ -339,6 +341,8 @@ the zero or more values passed to ``bar`` after its first two arguments:
In all these cases, ``x`` is bound to a tuple of the trailing values
passed to ``bar``.

It is possible to constrain the number of values passed as a variable argument; this will be discussed later in :ref:`man-vararg-fixedlen`.

On the flip side, it is often handy to "splice" the values contained in
an iterable collection into a function call as individual arguments. To
do this, one also uses ``...`` but in the function call instead:
Expand Down
26 changes: 26 additions & 0 deletions doc/manual/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,32 @@ can also constrain type parameters of methods::
The ``same_type_numeric`` function behaves much like the ``same_type``
function defined above, but is only defined for pairs of numbers.

.. _man-vararg-fixedlen:

Parametrically-constrained Varargs methods
------------------------------------------

Function parameters can also be used to constrain the number of arguments that may be supplied to a "varargs" function (:ref:`man-vararg-functions`). The notation ``...N`` is used to indicate such a constraint. For example:

.. doctest::

julia> bar(a,b,x...2) = (a,b,x)

julia> bar(1,2,3)
ERROR: MethodError: `bar` has no method matching bar(::Int, ::Int, ::Int)

julia> bar(1,2,3,4)
(1,2,(3,4))

julia> bar(1,2,3,4,5)
ERROR: MethodError: `bar` has no method matching bar(::Int, ::Int, ::Int, ::Int, ::Int)

More usefully, it is possible to constrain varargs methods by a parameter. For example::

function getindex{T,N}(A::AbstractArray{T,N}, indexes::Number...N)

would be called only when the number of ``indexes`` matches the dimensionality of the array.

Note on Optional and keyword Arguments
--------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ int jl_is_rest_arg(jl_value_t *ex)
if (((jl_expr_t*)ex)->head != colons_sym) return 0;
jl_expr_t *atype = (jl_expr_t*)jl_exprarg(ex,1);
if (!jl_is_expr(atype)) return 0;
if (atype->head != call_sym || jl_array_len(atype->args) != 3)
if (atype->head != call_sym || jl_array_len(atype->args) < 3 || jl_array_len(atype->args) > 4)
return 0;
if ((jl_sym_t*)jl_exprarg(atype,1) != dots_sym)
return 0;
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ size_t jl_static_show_x(JL_STREAM *out, jl_value_t *v, int depth)
else if (jl_is_vararg_type(v)) {
n += jl_static_show_x(out, jl_tparam0(v), depth);
n += jl_printf(out, "...");
n += jl_static_show_x(out, jl_tparam1(v), depth);
}
else if (jl_is_datatype(v)) {
jl_datatype_t *dv = (jl_datatype_t*)v;
Expand Down
99 changes: 51 additions & 48 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -738,57 +738,58 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
// in general, here we want to find the biggest type that's not a
// supertype of any other method signatures. so far we are conservative
// and the types we find should be bigger.
if (!isstaged && jl_tuple_len(type) > mt->max_args &&
jl_is_vararg_type(jl_tupleref(decl,jl_tuple_len(decl)-1))) {
size_t nspec = mt->max_args + 2;
limited = jl_alloc_tuple(nspec);
for(i=0; i < nspec-1; i++) {
jl_tupleset(limited, i, jl_tupleref(type, i));
}
jl_value_t *lasttype = jl_tupleref(type,i-1);
// if all subsequent arguments are subtypes of lasttype, specialize
// on that instead of decl. for example, if decl is
// (Any...)
// and type is
// (Symbol, Symbol, Symbol)
// then specialize as (Symbol...), but if type is
// (Symbol, Int32, Expr)
// then specialize as (Any...)
size_t j = i;
int all_are_subtypes=1;
for(; j < jl_tuple_len(type); j++) {
if (!jl_subtype(jl_tupleref(type,j), lasttype, 0)) {
all_are_subtypes = 0;
break;
if (!isstaged && jl_tuple_len(type) > mt->max_args) {
jl_value_t *lastdeclt = jl_tupleref(decl,jl_tuple_len(decl)-1);
if (jl_is_vararg_type(lastdeclt) && !jl_is_vararg_fixedlen(lastdeclt)) {
size_t nspec = mt->max_args + 2;
limited = jl_alloc_tuple(nspec);
for(i=0; i < nspec-1; i++) {
jl_tupleset(limited, i, jl_tupleref(type, i));
}
}
type = limited;
if (all_are_subtypes) {
// avoid Type{Type{...}...}...
if (jl_is_type_type(lasttype) && jl_is_type_type(jl_tparam0(lasttype)))
lasttype = (jl_value_t*)jl_type_type;
temp = (jl_value_t*)jl_tuple1(lasttype);
jl_tupleset(type, i, jl_apply_type((jl_value_t*)jl_vararg_type,
(jl_tuple_t*)temp));
}
else {
jl_value_t *lastdeclt = jl_tupleref(decl,jl_tuple_len(decl)-1);
if (jl_tuple_len(sparams) > 0) {
lastdeclt = (jl_value_t*)
jl_instantiate_type_with((jl_value_t*)lastdeclt,
sparams->data,
jl_tuple_len(sparams)/2);
jl_value_t *lasttype = jl_tupleref(type,i-1);
// if all subsequent arguments are subtypes of lasttype, specialize
// on that instead of decl. for example, if decl is
// (Any...)
// and type is
// (Symbol, Symbol, Symbol)
// then specialize as (Symbol...), but if type is
// (Symbol, Int32, Expr)
// then specialize as (Any...)
size_t j = i;
int all_are_subtypes=1;
for(; j < jl_tuple_len(type); j++) {
if (!jl_subtype(jl_tupleref(type,j), lasttype, 0)) {
all_are_subtypes = 0;
break;
}
}
type = limited;
if (all_are_subtypes) {
// avoid Type{Type{...}...}...
if (jl_is_type_type(lasttype) && jl_is_type_type(jl_tparam0(lasttype)))
lasttype = (jl_value_t*)jl_type_type;
temp = (jl_value_t*)jl_tuple1(lasttype);
jl_tupleset(type, i, jl_apply_type((jl_value_t*)jl_vararg_type,
(jl_tuple_t*)temp));
}
else {
if (jl_tuple_len(sparams) > 0) {
lastdeclt = (jl_value_t*)
jl_instantiate_type_with((jl_value_t*)lastdeclt,
sparams->data,
jl_tuple_len(sparams)/2);
}
jl_tupleset(type, i, lastdeclt);
}
jl_tupleset(type, i, lastdeclt);
// now there is a problem: the computed signature is more
// general than just the given arguments, so it might conflict
// with another definition that doesn't have cache instances yet.
// to fix this, we insert guard cache entries for all intersections
// of this signature and definitions. those guard entries will
// supersede this one in conflicted cases, alerting us that there
// should actually be a cache miss.
need_guard_entries = 1;
}
// now there is a problem: the computed signature is more
// general than just the given arguments, so it might conflict
// with another definition that doesn't have cache instances yet.
// to fix this, we insert guard cache entries for all intersections
// of this signature and definitions. those guard entries will
// supersede this one in conflicted cases, alerting us that there
// should actually be a cache miss.
need_guard_entries = 1;
}

if (need_guard_entries) {
Expand Down Expand Up @@ -924,6 +925,8 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
return newmeth;
}

// a holds the argument types, b the argument signature, tvars the parameters.
// On output, *penv holds (parameter1, value1, ...) pairs from intersection.
static jl_value_t *lookup_match(jl_value_t *a, jl_value_t *b, jl_tuple_t **penv,
jl_tuple_t *tvars)
{
Expand Down
Loading