Skip to content

Commit 2ecc499

Browse files
authored
Merge pull request #30577 from JuliaLang/jb/splatnew
allow splatting in calls to `new`
2 parents c6c3d72 + e456a72 commit 2ecc499

18 files changed

+200
-67
lines changed

NEWS.md

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ Julia v1.2 Release Notes
44
New language features
55
---------------------
66

7-
* The `extrema` function now accepts a function argument in the same manner as `minimum` and
8-
`maximum` ([#30323]).
9-
* `hasmethod` can now check for matching keyword argument names ([#30712]).
10-
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).
7+
* Argument splatting (`x...`) can now be used in calls to the `new` pseudo-function in
8+
constructors ([#30577]).
119

1210
Multi-threading changes
1311
-----------------------
@@ -35,6 +33,11 @@ New library functions
3533
Standard library changes
3634
------------------------
3735

36+
* The `extrema` function now accepts a function argument in the same manner as `minimum` and
37+
`maximum` ([#30323]).
38+
* `hasmethod` can now check for matching keyword argument names ([#30712]).
39+
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).
40+
3841
#### LinearAlgebra
3942

4043
* Added keyword arguments `rtol`, `atol` to `pinv` and `nullspace` ([#29998]).

base/boot.jl

+2-27
Original file line numberDiff line numberDiff line change
@@ -548,33 +548,8 @@ NamedTuple{names}(args::Tuple) where {names} = NamedTuple{names,typeof(args)}(ar
548548

549549
using .Intrinsics: sle_int, add_int
550550

551-
macro generated()
552-
return Expr(:generated)
553-
end
554-
555-
function NamedTuple{names,T}(args::T) where {names, T <: Tuple}
556-
if @generated
557-
N = nfields(names)
558-
flds = Array{Any,1}(undef, N)
559-
i = 1
560-
while sle_int(i, N)
561-
arrayset(false, flds, :(getfield(args, $i)), i)
562-
i = add_int(i, 1)
563-
end
564-
Expr(:new, :(NamedTuple{names,T}), flds...)
565-
else
566-
N = nfields(names)
567-
NT = NamedTuple{names,T}
568-
flds = Array{Any,1}(undef, N)
569-
i = 1
570-
while sle_int(i, N)
571-
arrayset(false, flds, getfield(args, i), i)
572-
i = add_int(i, 1)
573-
end
574-
ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), NT,
575-
ccall(:jl_array_ptr, Ptr{Cvoid}, (Any,), flds), toUInt32(N))::NT
576-
end
577-
end
551+
eval(Core, :(NamedTuple{names,T}(args::T) where {names, T <: Tuple} =
552+
$(Expr(:splatnew, :(NamedTuple{names,T}), :args))))
578553

579554
# constructors for built-in types
580555

base/compiler/abstractinterpretation.jl

+3
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
907907
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
908908
end
909909
end
910+
elseif e.head === :splatnew
911+
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
912+
# TODO: improve
910913
elseif e.head === :&
911914
abstract_eval(e.args[1], vtypes, sv)
912915
t = Any

base/compiler/ssair/inlining.jl

+25
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,31 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
782782
todo = Any[]
783783
for idx in 1:length(ir.stmts)
784784
stmt = ir.stmts[idx]
785+
786+
if isexpr(stmt, :splatnew)
787+
ty = ir.types[idx]
788+
nf = nfields_tfunc(ty)
789+
if nf isa Const
790+
eargs = stmt.args
791+
tup = eargs[2]
792+
tt = argextype(tup, ir, sv.sptypes)
793+
tnf = nfields_tfunc(tt)
794+
if tnf isa Const && tnf.val <= nf.val
795+
n = tnf.val
796+
new_argexprs = Any[eargs[1]]
797+
for j = 1:n
798+
atype = getfield_tfunc(tt, Const(j))
799+
new_call = Expr(:call, Core.getfield, tup, j)
800+
new_argexpr = insert_node!(ir, idx, atype, new_call)
801+
push!(new_argexprs, new_argexpr)
802+
end
803+
stmt.head = :new
804+
stmt.args = new_argexprs
805+
end
806+
end
807+
continue
808+
end
809+
785810
isexpr(stmt, :call) || continue
786811
eargs = stmt.args
787812
isempty(eargs) && continue

base/compiler/ssair/ir.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ function getindex(x::UseRef)
325325
end
326326

327327
function is_relevant_expr(e::Expr)
328-
return e.head in (:call, :invoke, :new, :(=), :(&),
328+
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
329329
:gc_preserve_begin, :gc_preserve_end,
330330
:foreigncall, :isdefined, :copyast,
331331
:undefcheck, :throw_undef_if_not,

base/compiler/tfuncs.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -329,17 +329,17 @@ function sizeof_tfunc(@nospecialize(x),)
329329
return Int
330330
end
331331
add_tfunc(Core.sizeof, 1, 1, sizeof_tfunc, 0)
332-
add_tfunc(nfields, 1, 1,
333-
function (@nospecialize(x),)
334-
isa(x, Const) && return Const(nfields(x.val))
335-
isa(x, Conditional) && return Const(0)
336-
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
337-
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
338-
return Const(length(x.types))
339-
end
332+
function nfields_tfunc(@nospecialize(x))
333+
isa(x, Const) && return Const(nfields(x.val))
334+
isa(x, Conditional) && return Const(0)
335+
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
336+
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
337+
return Const(length(x.types))
340338
end
341-
return Int
342-
end, 0)
339+
end
340+
return Int
341+
end
342+
add_tfunc(nfields, 1, 1, nfields_tfunc, 0)
343343
add_tfunc(Core._expr, 1, INT_INF, (@nospecialize args...)->Expr, 100)
344344
function typevar_tfunc(@nospecialize(n), @nospecialize(lb_arg), @nospecialize(ub_arg))
345345
lb = Union{}

base/compiler/validation.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}(
1111
:method => 1:4,
1212
:const => 1:1,
1313
:new => 1:typemax(Int),
14+
:splatnew => 2:2,
1415
:return => 1:1,
1516
:unreachable => 0:0,
1617
:the_exception => 0:0,
@@ -142,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
142143
head === :inbounds || head === :foreigncall || head === :cfunction ||
143144
head === :const || head === :enter || head === :leave || head == :pop_exception ||
144145
head === :method || head === :global || head === :static_parameter ||
145-
head === :new || head === :thunk || head === :simdloop ||
146+
head === :new || head === :splatnew || head === :thunk || head === :simdloop ||
146147
head === :throw_undef_if_not || head === :unreachable
147148
validate_val!(x)
148149
else
@@ -224,7 +225,7 @@ end
224225

225226
function is_valid_rvalue(@nospecialize(x))
226227
is_valid_argument(x) && return true
227-
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
228+
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
228229
return true
229230
end
230231
return false

base/essentials.jl

+7
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,13 @@ convert(::Type{Tuple{Vararg{V}}}, x::Tuple{Vararg{V}}) where {V} = x
280280
convert(T::Type{Tuple{Vararg{V}}}, x::Tuple) where {V} =
281281
(convert(tuple_type_head(T), x[1]), convert(T, tail(x))...)
282282

283+
# used for splatting in `new`
284+
convert_prefix(::Type{Tuple{}}, x::Tuple) = x
285+
convert_prefix(::Type{<:AtLeast1}, x::Tuple{}) = x
286+
convert_prefix(::Type{T}, x::T) where {T<:AtLeast1} = x
287+
convert_prefix(::Type{T}, x::AtLeast1) where {T<:AtLeast1} =
288+
(convert(tuple_type_head(T), x[1]), convert_prefix(tuple_type_tail(T), tail(x))...)
289+
283290
# TODO: the following definitions are equivalent (behaviorally) to the above method
284291
# I think they may be faster / more efficient for inference,
285292
# if we could enable them, but are they?

base/show.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1223,8 +1223,8 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int)
12231223
end
12241224

12251225
# new expr
1226-
elseif head === :new
1227-
show_enclosed_list(io, "%new(", args, ", ", ")", indent)
1226+
elseif head === :new || head === :splatnew
1227+
show_enclosed_list(io, "%$head(", args, ", ", ")", indent)
12281228

12291229
# other call-like expressions ("A[1,2]", "T{X,Y}", "f.(X,Y)")
12301230
elseif haskey(expr_calls, head) && nargs >= 1 # :ref/:curly/:calldecl/:(.)

doc/src/devdocs/ast.md

+5
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,11 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.
359359
to this, and the type is always inserted by the compiler. This is very much an internal-only
360360
feature, and does no checking. Evaluating arbitrary `new` expressions can easily segfault.
361361

362+
* `splatnew`
363+
364+
Similar to `new`, except field values are passed as a single tuple. Works similarly to
365+
`Base.splat(new)` if `new` were a first-class function, hence the name.
366+
362367
* `return`
363368

364369
Returns its argument as the value of the enclosing function.

src/ast.c

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ jl_sym_t *enter_sym; jl_sym_t *leave_sym;
4141
jl_sym_t *pop_exception_sym;
4242
jl_sym_t *exc_sym; jl_sym_t *error_sym;
4343
jl_sym_t *new_sym; jl_sym_t *using_sym;
44+
jl_sym_t *splatnew_sym;
4445
jl_sym_t *const_sym; jl_sym_t *thunk_sym;
4546
jl_sym_t *abstracttype_sym; jl_sym_t *primtype_sym;
4647
jl_sym_t *structtype_sym; jl_sym_t *foreigncall_sym;
@@ -325,6 +326,7 @@ void jl_init_frontend(void)
325326
leave_sym = jl_symbol("leave");
326327
pop_exception_sym = jl_symbol("pop_exception");
327328
new_sym = jl_symbol("new");
329+
splatnew_sym = jl_symbol("splatnew");
328330
const_sym = jl_symbol("const");
329331
global_sym = jl_symbol("global");
330332
thunk_sym = jl_symbol("thunk");

src/codegen.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ static Function *jltls_states_func;
264264

265265
// important functions
266266
static Function *jlnew_func;
267+
static Function *jlsplatnew_func;
267268
static Function *jlthrow_func;
268269
static Function *jlerror_func;
269270
static Function *jltypeerror_func;
@@ -4069,6 +4070,15 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
40694070
// it to the inferred type.
40704071
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
40714072
}
4073+
else if (head == splatnew_sym) {
4074+
jl_cgval_t argv[2];
4075+
argv[0] = emit_expr(ctx, args[0]);
4076+
argv[1] = emit_expr(ctx, args[1]);
4077+
Value *typ = boxed(ctx, argv[0]);
4078+
Value *tup = boxed(ctx, argv[1]);
4079+
Value *val = ctx.builder.CreateCall(prepare_call(jlsplatnew_func), { typ, tup });
4080+
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
4081+
}
40724082
else if (head == exc_sym) {
40734083
return mark_julia_type(ctx,
40744084
ctx.builder.CreateCall(prepare_call(jl_current_exception_func)),
@@ -6981,6 +6991,17 @@ static void init_julia_llvm_env(Module *m)
69816991
jlnew_func->addFnAttr(Thunk);
69826992
add_named_global(jlnew_func, &jl_new_structv);
69836993

6994+
std::vector<Type *> args_2rptrs_(0);
6995+
args_2rptrs_.push_back(T_prjlvalue);
6996+
args_2rptrs_.push_back(T_prjlvalue);
6997+
jlsplatnew_func =
6998+
Function::Create(FunctionType::get(T_prjlvalue, args_2rptrs_, false),
6999+
Function::ExternalLinkage,
7000+
"jl_new_structt", m);
7001+
add_return_attr(jlsplatnew_func, Attribute::NonNull);
7002+
jlsplatnew_func->addFnAttr(Thunk);
7003+
add_named_global(jlsplatnew_func, &jl_new_structt);
7004+
69847005
std::vector<Type*> args2(0);
69857006
args2.push_back(T_pint8);
69867007
#ifndef _OS_WINDOWS_

src/datatype.c

+50-12
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,24 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
797797
return jv;
798798
}
799799

800-
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
801-
uint32_t na)
800+
static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
801+
{
802+
size_t nf = jl_datatype_nfields(type);
803+
for(size_t i=na; i < nf; i++) {
804+
if (jl_field_isptr(type, i)) {
805+
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
806+
}
807+
else {
808+
jl_value_t *ft = jl_field_type(type, i);
809+
if (jl_is_uniontype(ft)) {
810+
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
811+
*psel = 0;
812+
}
813+
}
814+
}
815+
}
816+
817+
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na)
802818
{
803819
jl_ptls_t ptls = jl_get_ptls_states();
804820
if (type->instance != NULL) {
@@ -811,7 +827,6 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
811827
}
812828
if (type->layout == NULL)
813829
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
814-
size_t nf = jl_datatype_nfields(type);
815830
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
816831
JL_GC_PUSH1(&jv);
817832
for (size_t i = 0; i < na; i++) {
@@ -820,18 +835,41 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
820835
jl_type_error("new", ft, args[i]);
821836
jl_set_nth_field(jv, i, args[i]);
822837
}
823-
for(size_t i=na; i < nf; i++) {
824-
if (jl_field_isptr(type, i)) {
825-
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
826-
}
827-
else {
838+
init_struct_tail(type, jv, na);
839+
JL_GC_POP();
840+
return jv;
841+
}
842+
843+
JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
844+
{
845+
jl_ptls_t ptls = jl_get_ptls_states();
846+
if (!jl_is_tuple(tup))
847+
jl_type_error("new", (jl_value_t*)jl_tuple_type, tup);
848+
size_t na = jl_nfields(tup);
849+
size_t nf = jl_datatype_nfields(type);
850+
if (na > nf)
851+
jl_too_many_args("new", nf);
852+
if (type->instance != NULL) {
853+
for (size_t i = 0; i < na; i++) {
828854
jl_value_t *ft = jl_field_type(type, i);
829-
if (jl_is_uniontype(ft)) {
830-
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
831-
*psel = 0;
832-
}
855+
jl_value_t *fi = jl_get_nth_field(tup, i);
856+
if (!jl_isa(fi, ft))
857+
jl_type_error("new", ft, fi);
833858
}
859+
return type->instance;
860+
}
861+
if (type->layout == NULL)
862+
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
863+
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
864+
JL_GC_PUSH1(&jv);
865+
for (size_t i = 0; i < na; i++) {
866+
jl_value_t *ft = jl_field_type(type, i);
867+
jl_value_t *fi = jl_get_nth_field(tup, i);
868+
if (!jl_isa(fi, ft))
869+
jl_type_error("new", ft, fi);
870+
jl_set_nth_field(jv, i, fi);
834871
}
872+
init_struct_tail(type, jv, na);
835873
JL_GC_POP();
836874
return jv;
837875
}

src/interpreter.c

+10
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,16 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
467467
JL_GC_POP();
468468
return v;
469469
}
470+
else if (head == splatnew_sym) {
471+
jl_value_t **argv;
472+
JL_GC_PUSHARGS(argv, 2);
473+
argv[0] = eval_value(args[0], s);
474+
argv[1] = eval_value(args[1], s);
475+
assert(jl_is_structtype(argv[0]));
476+
jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]);
477+
JL_GC_POP();
478+
return v;
479+
}
470480
else if (head == static_parameter_sym) {
471481
ssize_t n = jl_unbox_long(args[0]);
472482
assert(n > 0);

0 commit comments

Comments
 (0)