Skip to content

Commit 8b74afb

Browse files
committed
RFC: Change lowering of destructuring to avoid const prop dependence
I'm currently doing some work with inference passes that have const prop (temporarily) disabled and I noticed we actually rely on it quite a bit for basic things. That's not terrible - const prop works pretty well after all, but it still imposes a cost and while I want to support it in my AD use case also, it makes destructuring quite expensive, because everything needs to be inferred twice. This PR is an experiment in changing the lowering to avoid having to const prop the index. Rather than lowering `(a,b,c) = foo()` as: ``` it = foo() a, s = indexed_iterate(it, 1) b, s = indexed_iterate(it, 2) c, s = indexed_iterate(it, 3) ``` we lower as: ``` it = foo() iterate, index = iterate_and_index(it) x = iterate(it) a = index(x, 1) y = iterate(it, y) b = index(y, 2) z = iterate(it, z) c = index(z, 3) ``` For tuples `iterate` would simply return the first argument and `index` would be `getfield`. That way, there is no const prop, since `getfield` is called directly and inference can directly use its tfunc. For the fallback case `iterate` is basically just `Base.iterate`, with just a slight tweak to give an intelligent error for short iterables. On simple functions, there isn't much of a difference in execution time, but benchmarking something more complicated like: ``` function g() a, = getfield(((1,),(2.0,3),("x",),(:x,)), Base.inferencebarrier(1)) nothing end ``` shows about a 20% improvement in end-to-end inference/optimize time, which is substantial.
1 parent 6de97d5 commit 8b74afb

File tree

10 files changed

+65
-30
lines changed

10 files changed

+65
-30
lines changed

base/missing.jl

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ convert(::Type{T}, x::T) where {T>:Union{Missing, Nothing}} = x
6969
convert(::Type{T}, x) where {T>:Missing} = convert(nonmissingtype_checked(T), x)
7070
convert(::Type{T}, x) where {T>:Union{Missing, Nothing}} = convert(nonmissingtype_checked(nonnothingtype_checked(T)), x)
7171

72+
# Hoisting this MethodError to `iterate_and_index` makes inference's job easier
73+
iterate_and_index(::Missing) = throw(MethodError(iterate, (missing,)))
7274

7375
# Comparison operators
7476
==(::Missing, ::Missing) = missing

base/namedtuple.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ firstindex(t::NamedTuple) = 1
112112
lastindex(t::NamedTuple) = nfields(t)
113113
getindex(t::NamedTuple, i::Int) = getfield(t, i)
114114
getindex(t::NamedTuple, i::Symbol) = getfield(t, i)
115-
indexed_iterate(t::NamedTuple, i::Int, state=1) = (getfield(t, i), i+1)
116115
isempty(::NamedTuple{()}) = true
117116
isempty(::NamedTuple) = false
118117
empty(::NamedTuple) = NamedTuple()
118+
index_and_iterate(t::NamedTuple) = (arg1, getfield)
119119

120120
convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T<:Tuple} = nt
121121
convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt

base/pair.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Pair, =>
4747

4848
eltype(p::Type{Pair{A, B}}) where {A, B} = Union{A, B}
4949
iterate(p::Pair, i=1) = i > 2 ? nothing : (getfield(p, i), i + 1)
50-
indexed_iterate(p::Pair, i::Int, state=1) = (getfield(p, i), i + 1)
50+
iterate_and_index(p::Pair) = (arg1, getfield)
5151

5252
hash(p::Pair, h::UInt) = hash(p.second, hash(p.first, h))
5353

base/tuple.jl

+33-13
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,41 @@ function _maxlength(t::Tuple, t2::Tuple, t3::Tuple...)
8181
max(length(t), _maxlength(t2, t3...))
8282
end
8383

84-
# this allows partial evaluation of bounded sequences of next() calls on tuples,
85-
# while reducing to plain next() for arbitrary iterables.
86-
indexed_iterate(t::Tuple, i::Int, state=1) = (@_inline_meta; (getfield(t, i), i+1))
87-
indexed_iterate(a::Array, i::Int, state=1) = (@_inline_meta; (a[i], i+1))
88-
function indexed_iterate(I, i)
89-
x = iterate(I)
90-
x === nothing && throw(BoundsError(I, i))
91-
x
92-
end
93-
function indexed_iterate(I, i, state)
94-
x = iterate(I, state)
95-
x === nothing && throw(BoundsError(I, i))
96-
x
84+
# this allows partial evaluation of bounded sequences of iterate() calls on tuples,
85+
# while reducing to plain iterate() for arbitrary iterables.
86+
87+
arg1(a) = a
88+
arg1(a, b) = a
89+
iterate_and_index(t::Tuple) = (arg1, getfield)
90+
iterate_and_index(t::Array) = (arg1, getindex)
91+
92+
struct BadDestructure
93+
a
94+
end
95+
96+
function destruct_iterate(a)
97+
@_inline_meta
98+
s = iterate(a)
99+
s === nothing && return BadDestructure(a)
100+
s
97101
end
98102

103+
function destruct_iterate(a, b)
104+
@_inline_meta
105+
s = iterate(a, getfield(b, 2))
106+
s === nothing && return BadDestructure(a)
107+
s
108+
end
109+
110+
select_first(a::BadDestructure, i) = throw(BoundsError(a.a, i))
111+
select_first(a, i) = getfield(a, 1)
112+
113+
iterate_and_index(x) = (destruct_iterate, select_first)
114+
115+
# Nothing is often union'ed into other things. Kill that as quickly as possible
116+
# to make inference's life easier.
117+
iterate_and_index(::Nothing) = throw(MethodError(iterate, (nothing,)))
118+
99119
# Use dispatch to avoid a branch in first
100120
first(::Tuple{}) = throw(ArgumentError("tuple must be non-empty"))
101121
first(t::Tuple) = t[1]

src/common_symbols1.inc

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jl_symbol("*"),
3333
jl_symbol("bitcast"),
3434
jl_symbol("slt_int"),
3535
jl_symbol("isempty"),
36-
jl_symbol("indexed_iterate"),
36+
jl_symbol("iterate_and_index"),
3737
jl_symbol("size"),
3838
jl_symbol("!"),
3939
jl_symbol("nothing"),

src/julia-syntax.scm

+16-7
Original file line numberDiff line numberDiff line change
@@ -2049,17 +2049,26 @@
20492049
x (make-ssavalue)))
20502050
(ini (if (eq? x xx) '() (list (sink-assignment xx (expand-forms x)))))
20512051
(n (length lhss))
2052+
(funcs (make-ssavalue))
2053+
(iterate (make-ssavalue))
2054+
(index (make-ssavalue))
20522055
(st (gensy)))
20532056
`(block
20542057
,@ini
2058+
,(lower-tuple-assignment
2059+
(list iterate index)
2060+
`(call (top iterate_and_index) ,xx))
20552061
,.(map (lambda (i lhs)
2056-
(expand-forms
2057-
(lower-tuple-assignment
2058-
(if (= i (- n 1))
2059-
(list lhs)
2060-
(list lhs st))
2061-
`(call (top indexed_iterate)
2062-
,xx ,(+ i 1) ,.(if (eq? i 0) '() `(,st))))))
2062+
(expand-forms
2063+
`(block
2064+
(= ,st (call ,iterate
2065+
,xx ,.(if (eq? i 0) '() `(,st))))
2066+
,(if (eventually-call? lhs)
2067+
(let ((val (gensy)))
2068+
`(block
2069+
(= ,val (call ,index ,st ,(+ i 1)))
2070+
(= ,lhs ,val)))
2071+
`(= ,lhs (call ,index ,st ,(+ i 1)))))))
20632072
(iota n)
20642073
lhss)
20652074
(unnecessary ,xx))))))

stdlib/Serialization/src/Serialization.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Serializer(io::IO) = Serializer{typeof(io)}(io)
3030
## serializing values ##
3131

3232
const n_int_literals = 33
33-
const n_reserved_slots = 24
33+
const n_reserved_slots = 23
3434
const n_reserved_tags = 8
3535

3636
const TAGS = Any[
@@ -69,6 +69,7 @@ const TAGS = Any[
6969
:indexed_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc,
7070
:pop, :arrayset, :arrayref, :apply_type, :inbounds, :getindex, :setindex!, :Core, :!, :+,
7171
:Base, :static_parameter, :convert, :colon, Symbol("#self#"), Symbol("#temp#"), :tuple, Symbol(""),
72+
:iterate_and_index,
7273

7374
fill(:_reserved_, n_reserved_slots)...,
7475

test/compiler/inference.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -562,10 +562,7 @@ end
562562

563563
function g19348(x)
564564
a, b = x
565-
g = 1
566-
g = 2
567-
c = Base.indexed_iterate(x, g, g)
568-
return a + b + c[1]
565+
return a + b
569566
end
570567

571568
for (codetype, all_ssa) in Any[

test/core.jl

+8
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,14 @@ let (f(), x) = (1, 2)
557557
@test x == 2
558558
end
559559

560+
foo23091_cnt = 0
561+
struct Foo23091; end
562+
Base.iterate(::Foo23091, state...) = (global foo23091_cnt += 1; (1, nothing))
563+
(g23091(), h23091()) = Foo23091()
564+
@test foo23091_cnt == 2
565+
g23091(); h23091()
566+
@test foo23091_cnt == 2
567+
560568
# issue #21900
561569
f21900_cnt = 0
562570
function f21900()

test/dict.jl

-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ using Random
1111
@test iterate(p, iterate(p, iterate(p)[2])[2]) == nothing
1212
@test firstindex(p) == 1
1313
@test lastindex(p) == length(p) == 2
14-
@test Base.indexed_iterate(p, 1, nothing) == (10,2)
15-
@test Base.indexed_iterate(p, 2, nothing) == (20,3)
1614
@test (1=>2) < (2=>3)
1715
@test (2=>2) < (2=>3)
1816
@test !((2=>3) < (2=>3))

0 commit comments

Comments
 (0)