Skip to content

Commit d515f05

Browse files
authored
optimizer: refactor SROA pass (#43232)
- avoid domtree construction when there are no mutables to eliminate - reduce # of dynamic allocations - separate some computations into individual functions
1 parent 5c357e9 commit d515f05

File tree

1 file changed

+91
-82
lines changed

1 file changed

+91
-82
lines changed

base/compiler/ssair/passes.jl

+91-82
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2323

2424
compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)
2525

26-
function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
26+
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
2727
field = stmt.args[3]
2828
# fields are usually literals, handle them manually
2929
if isa(field, QuoteNode)
3030
field = field.value
3131
elseif isa(field, Int)
3232
# try to resolve other constants, e.g. global reference
3333
else
34-
field = compact_exprtype(compact, field)
34+
field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir)
3535
if isa(field, Const)
3636
field = field.val
3737
else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
4242
return field
4343
end
4444

45-
function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
46-
field = try_compute_field_stmt(compact, stmt)
45+
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
46+
field = try_compute_field_stmt(ir, stmt)
4747
return try_compute_fieldidx(typ, field)
4848
end
4949

@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112112
return def, stmtblock, curblock
113113
end
114114

115+
function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint))
116+
if isa(val, Union{OldSSAValue, SSAValue})
117+
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
118+
end
119+
return walk_to_defs(compact, val, typeconstraint)
120+
end
121+
115122
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
116123
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
117124
while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152159
end
153160

154161
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
155-
@nospecialize(typeconstraint = types(compact)[defssa]))
162+
@nospecialize(typeconstraint))
156163
callback = function (@nospecialize(pi), @nospecialize(idx))
157164
if isa(pi, PiNode)
158165
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164171
end
165172

166173
"""
167-
walk_to_defs(compact, val, intermediaries)
174+
walk_to_defs(compact, val, typeconstraint)
168175
169-
Starting at `val` walk use-def chains to get all the leaves feeding into
170-
this val (pruning those leaves rules out by path conditions).
176+
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177+
(pruning those leaves rules out by path conditions).
171178
"""
172-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[])
173-
isa(defssa, AnySSAValue) || return Any[defssa]
179+
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
180+
visited_phinodes = AnySSAValue[]
181+
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174182
def = compact[defssa]
175-
isa(def, PhiNode) || return Any[defssa]
176-
# Step 2: Figure out what the struct is defined as
177-
## Track definitions through PiNode/PhiNode
178-
found_def = false
179-
## Track which PhiNodes, SSAValue intermediaries
180-
## we forwarded through.
183+
isa(def, PhiNode) || return Any[defssa], visited_phinodes
181184
visited_constraints = IdDict{AnySSAValue, Any}()
182185
worklist_defs = AnySSAValue[]
183186
worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239242
push!(leaves, defssa)
240243
end
241244
end
242-
leaves
245+
return leaves, visited_phinodes
243246
end
244247

245-
function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
248+
function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
246249
for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
247250
if !isbitstype(widenconst(compact_exprtype(compact, arg)))
248251
push!(new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449452
return
450453
end
451454

452-
if isa(val, Union{OldSSAValue, SSAValue})
453-
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
454-
end
455-
456-
visited_phinodes = AnySSAValue[]
457-
leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)
455+
valtyp = widenconst(compact_exprtype(compact, val))
456+
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting
458457

458+
leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
459459
length(leaves) 1 && return # bail out if we don't have multiple leaves
460460

461461
# Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476476
visited_phinodes, cmp, lifting_cache, Bool,
477477
lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue
478478

479-
# global assertion_counter
480-
# assertion_counter::Int += 1
481-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482-
# return
483479
compact[idx] = lifted_val.x
484480
end
485481

@@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
576572
return stmt_val # N.B. should never happen
577573
end
578574

575+
# NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
576+
# which can be very large sometimes, and analyzed program counters are often very sparse
577+
const SPCSet = IdSet{Int}
578+
579579
"""
580580
sroa_pass!(ir::IRCode) -> newir::IRCode
581581
@@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
596596
"""
597597
function sroa_pass!(ir::IRCode)
598598
compact = IncrementalCompact(ir)
599-
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
599+
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600600
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
601601
for ((_, idx), stmt) in compact
602+
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602603
isa(stmt, Expr) || continue
603-
result_t = compact_exprtype(compact, SSAValue(idx))
604604
is_setfield = false
605605
field_ordering = :unspecified
606-
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
607606
if is_known_call(stmt, setfield!, compact)
608-
is_setfield = true
609607
4 <= length(stmt.args) <= 5 || continue
608+
is_setfield = true
610609
if length(stmt.args) == 5
611610
field_ordering = compact_exprtype(compact, stmt.args[5])
612611
end
@@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
624623
old_preserves = stmt.args[(6+nccallargs):end]
625624
for (pidx, preserved_arg) in enumerate(old_preserves)
626625
isa(preserved_arg, SSAValue) || continue
627-
let intermediaries = IdSet{Int}()
626+
let intermediaries = SPCSet()
628627
callback = function (@nospecialize(pi), @nospecialize(ssa))
629628
push!(intermediaries, ssa.id)
630629
return false
@@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
634633
defidx = def.id
635634
def = compact[defidx]
636635
if is_tuple_call(compact, def)
637-
process_immutable_preserve(new_preserves, compact, def)
636+
process_immutable_preserve!(new_preserves, compact, def)
638637
old_preserves[pidx] = nothing
639638
continue
640639
elseif isexpr(def, :new)
@@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
643642
typ = unwrap_unionall(typ)
644643
end
645644
if typ isa DataType && !ismutabletype(typ)
646-
process_immutable_preserve(new_preserves, compact, def)
645+
process_immutable_preserve!(new_preserves, compact, def)
647646
old_preserves[pidx] = nothing
648647
continue
649648
end
650649
else
651650
continue
652651
end
653-
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
652+
if defuses === nothing
653+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
654+
end
655+
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
654656
push!(defuse.ccall_preserve_uses, idx)
655657
union!(mid, intermediaries)
656658
end
@@ -675,10 +677,15 @@ function sroa_pass!(ir::IRCode)
675677
else
676678
continue
677679
end
680+
681+
# analyze this `getfield` / `setfield!` call
682+
678683
field = try_compute_field_stmt(compact, stmt)
679684
field === nothing && continue
680685

681-
struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2])))
686+
val = stmt.args[2]
687+
688+
struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, val)))
682689
if isa(struct_typ, Union) && struct_typ <: Tuple
683690
struct_typ = unswitchtupleunion(struct_typ)
684691
end
@@ -689,19 +696,21 @@ function sroa_pass!(ir::IRCode)
689696
continue
690697
end
691698

692-
def, typeconstraint = stmt.args[2], struct_typ
693-
699+
# analyze this mutable struct here for the later pass
694700
if ismutabletype(struct_typ)
695-
isa(def, SSAValue) || continue
696-
let intermediaries = IdSet{Int}()
701+
isa(val, SSAValue) || continue
702+
let intermediaries = SPCSet()
697703
callback = function (@nospecialize(pi), @nospecialize(ssa))
698704
push!(intermediaries, ssa.id)
699705
return false
700706
end
701-
def = simple_walk(compact, def, callback)
707+
def = simple_walk(compact, val, callback)
702708
# Mutable stuff here
703709
isa(def, SSAValue) || continue
704-
mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
710+
if defuses === nothing
711+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
712+
end
713+
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
705714
if is_setfield
706715
push!(defuse.defs, idx)
707716
else
@@ -711,32 +720,28 @@ function sroa_pass!(ir::IRCode)
711720
end
712721
continue
713722
elseif is_setfield
714-
continue
723+
continue # invalid `setfield!` call, but just ignore here
715724
end
716725

717726
# perform SROA on immutable structs here on
718727

719-
if isa(def, Union{OldSSAValue, SSAValue})
720-
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
721-
end
722-
723-
visited_phinodes = AnySSAValue[]
724-
leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)
725-
726-
isempty(leaves) && continue
727-
728728
field = try_compute_fieldidx(struct_typ, field)
729729
field === nothing && continue
730730

731-
r = lift_leaves(compact, result_t, field, leaves)
732-
r === nothing && continue
733-
lifted_leaves, any_undef = r
731+
leaves, visited_phinodes = collect_leaves(compact, val, struct_typ)
732+
isempty(leaves) && continue
733+
734+
result_t = compact_exprtype(compact, SSAValue(idx))
735+
lifted_result = lift_leaves(compact, result_t, field, leaves)
736+
lifted_result === nothing && continue
737+
lifted_leaves, any_undef = lifted_result
734738

735739
if any_undef
736740
result_t = make_MaybeUndef(result_t)
737741
end
738742

739-
val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2])
743+
val = perform_lifting!(compact,
744+
visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val)
740745

741746
# Insert the undef check if necessary
742747
if any_undef
@@ -750,28 +755,32 @@ function sroa_pass!(ir::IRCode)
750755
@assert val !== nothing
751756
end
752757

753-
# global assertion_counter
754-
# assertion_counter::Int += 1
755-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756-
# continue
757758
compact[idx] = val === nothing ? nothing : val.x
758759
end
759760

760761
non_dce_finish!(compact)
761-
# Copy the use count, `simple_dce!` may modify it and for our predicate
762-
# below we need it consistent with the state of the IR here (after tracking
763-
# phi node arguments, but before dce).
764-
used_ssas = copy(compact.used_ssas)
765-
simple_dce!(compact)
766-
ir = complete(compact)
767-
768-
# Compute domtree, needed below, now that we have finished compacting the
769-
# IR. This needs to be after we iterate through the IR with
770-
# `IncrementalCompact` because removing dead blocks can invalidate the
771-
# domtree.
762+
if defuses !== nothing
763+
# now go through analyzed mutable structs and see which ones we can eliminate
764+
# NOTE copy the use count here, because `simple_dce!` may modify it and we need it
765+
# consistent with the state of the IR here (after tracking `PhiNode` arguments,
766+
# but before the DCE) for our predicate within `sroa_mutables!`
767+
used_ssas = copy(compact.used_ssas)
768+
simple_dce!(compact)
769+
ir = complete(compact)
770+
sroa_mutables!(ir, defuses, used_ssas)
771+
return ir
772+
else
773+
simple_dce!(compact)
774+
return complete(compact)
775+
end
776+
end
777+
778+
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
779+
# Compute domtree, needed below, now that we have finished compacting the IR.
780+
# This needs to be after we iterate through the IR with `IncrementalCompact`
781+
# because removing dead blocks can invalidate the domtree.
772782
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
773783

774-
# Now go through any mutable structs and see which ones we can eliminate
775784
for (idx, (intermediaries, defuse)) in defuses
776785
intermediaries = collect(intermediaries)
777786
# Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +815,12 @@ function sroa_pass!(ir::IRCode)
806815
# it would have been deleted. That's fine, just ignore
807816
# the use in that case.
808817
stmt === nothing && continue
809-
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
818+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
810819
field === nothing && @goto skip
811820
push!(fielddefuse[field].uses, use)
812821
end
813822
for use in defuse.defs
814-
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
823+
field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ)
815824
field === nothing && @goto skip
816825
push!(fielddefuse[field].defs, use)
817826
end
@@ -846,8 +855,9 @@ function sroa_pass!(ir::IRCode)
846855
end
847856
end
848857
end
849-
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
850858
# Everything accounted for. Go field by field and perform idf
859+
preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
860+
IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses)))
851861
for fidx in 1:ndefuse
852862
du = fielddefuse[fidx]
853863
ftyp = fieldtype(typ, fidx)
@@ -863,8 +873,10 @@ function sroa_pass!(ir::IRCode)
863873
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
864874
end
865875
if !isbitstype(ftyp)
866-
for (use, list) in preserve_uses
867-
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
876+
if preserve_uses !== nothing
877+
for (use, list) in preserve_uses
878+
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
879+
end
868880
end
869881
end
870882
for b in phiblocks
@@ -881,7 +893,7 @@ function sroa_pass!(ir::IRCode)
881893
ir[SSAValue(stmt)] = nothing
882894
end
883895
end
884-
isempty(defuse.ccall_preserve_uses) && continue
896+
preserve_uses === nothing && continue
885897
push!(intermediaries, newidx)
886898
# Insert the new preserves
887899
for (use, new_preserves) in preserve_uses
@@ -897,10 +909,7 @@ function sroa_pass!(ir::IRCode)
897909

898910
@label skip
899911
end
900-
901-
return ir
902912
end
903-
# assertion_counter = 0
904913

905914
"""
906915
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)

0 commit comments

Comments
 (0)