Skip to content

Commit f0e684a

Browse files
committed
optimizer: refactor SROA pass
- use `BitSet` instead of `IdSet{Int}` - reduce # of dynamic allocations - separate some computations into individual functions
1 parent df0080c commit f0e684a

File tree

1 file changed

+83
-79
lines changed

1 file changed

+83
-79
lines changed

base/compiler/ssair/passes.jl

+83-79
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2323

2424
compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)
2525

26-
function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
26+
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
2727
field = stmt.args[3]
2828
# fields are usually literals, handle them manually
2929
if isa(field, QuoteNode)
3030
field = field.value
3131
elseif isa(field, Int)
3232
# try to resolve other constants, e.g. global reference
3333
else
34-
field = compact_exprtype(compact, field)
34+
field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir)
3535
if isa(field, Const)
3636
field = field.val
3737
else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
4242
return field
4343
end
4444

45-
function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
46-
field = try_compute_field_stmt(compact, stmt)
45+
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
46+
field = try_compute_field_stmt(ir, stmt)
4747
return try_compute_fieldidx(typ, field)
4848
end
4949

@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112112
return def, stmtblock, curblock
113113
end
114114

115+
function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint))
116+
if isa(val, Union{OldSSAValue, SSAValue})
117+
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
118+
end
119+
return walk_to_defs(compact, val, typeconstraint)
120+
end
121+
115122
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
116123
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
117124
while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152159
end
153160

154161
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
155-
@nospecialize(typeconstraint = types(compact)[defssa]))
162+
@nospecialize(typeconstraint))
156163
callback = function (@nospecialize(pi), @nospecialize(idx))
157164
if isa(pi, PiNode)
158165
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164171
end
165172

166173
"""
167-
walk_to_defs(compact, val, intermediaries)
174+
walk_to_defs(compact, val, typeconstraint)
168175
169-
Starting at `val` walk use-def chains to get all the leaves feeding into
170-
this val (pruning those leaves rules out by path conditions).
176+
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177+
(pruning those leaves rules out by path conditions).
171178
"""
172-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[])
173-
isa(defssa, AnySSAValue) || return Any[defssa]
179+
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
180+
visited_phinodes = AnySSAValue[]
181+
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174182
def = compact[defssa]
175-
isa(def, PhiNode) || return Any[defssa]
176-
# Step 2: Figure out what the struct is defined as
177-
## Track definitions through PiNode/PhiNode
178-
found_def = false
179-
## Track which PhiNodes, SSAValue intermediaries
180-
## we forwarded through.
183+
isa(def, PhiNode) || return Any[defssa], visited_phinodes
181184
visited_constraints = IdDict{AnySSAValue, Any}()
182185
worklist_defs = AnySSAValue[]
183186
worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239242
push!(leaves, defssa)
240243
end
241244
end
242-
leaves
245+
return leaves, visited_phinodes
243246
end
244247

245-
function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
248+
function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
246249
for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
247250
if !isbitstype(widenconst(compact_exprtype(compact, arg)))
248251
push!(new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449452
return
450453
end
451454

452-
if isa(val, Union{OldSSAValue, SSAValue})
453-
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
454-
end
455-
456-
visited_phinodes = AnySSAValue[]
457-
leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)
455+
valtyp = widenconst(compact_exprtype(compact, val))
456+
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting
458457

458+
leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
459459
length(leaves) 1 && return # bail out if we don't have multiple leaves
460460

461461
# Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476476
visited_phinodes, cmp, lifting_cache, Bool,
477477
lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue
478478

479-
# global assertion_counter
480-
# assertion_counter::Int += 1
481-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482-
# return
483479
compact[idx] = lifted_val.x
484480
end
485481

@@ -596,17 +592,16 @@ a result of succeeding dead code elimination.
596592
"""
597593
function sroa_pass!(ir::IRCode)
598594
compact = IncrementalCompact(ir)
599-
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
595+
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600596
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
601597
for ((_, idx), stmt) in compact
598+
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602599
isa(stmt, Expr) || continue
603-
result_t = compact_exprtype(compact, SSAValue(idx))
604600
is_setfield = false
605601
field_ordering = :unspecified
606-
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
607602
if is_known_call(stmt, setfield!, compact)
608-
is_setfield = true
609603
4 <= length(stmt.args) <= 5 || continue
604+
is_setfield = true
610605
if length(stmt.args) == 5
611606
field_ordering = compact_exprtype(compact, stmt.args[5])
612607
end
@@ -624,7 +619,7 @@ function sroa_pass!(ir::IRCode)
624619
old_preserves = stmt.args[(6+nccallargs):end]
625620
for (pidx, preserved_arg) in enumerate(old_preserves)
626621
isa(preserved_arg, SSAValue) || continue
627-
let intermediaries = IdSet{Int}()
622+
let intermediaries = BitSet()
628623
callback = function (@nospecialize(pi), @nospecialize(ssa))
629624
push!(intermediaries, ssa.id)
630625
return false
@@ -634,7 +629,7 @@ function sroa_pass!(ir::IRCode)
634629
defidx = def.id
635630
def = compact[defidx]
636631
if is_tuple_call(compact, def)
637-
process_immutable_preserve(new_preserves, compact, def)
632+
process_immutable_preserve!(new_preserves, compact, def)
638633
old_preserves[pidx] = nothing
639634
continue
640635
elseif isexpr(def, :new)
@@ -643,14 +638,17 @@ function sroa_pass!(ir::IRCode)
643638
typ = unwrap_unionall(typ)
644639
end
645640
if typ isa DataType && !ismutabletype(typ)
646-
process_immutable_preserve(new_preserves, compact, def)
641+
process_immutable_preserve!(new_preserves, compact, def)
647642
old_preserves[pidx] = nothing
648643
continue
649644
end
650645
else
651646
continue
652647
end
653-
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
648+
if defuses === nothing
649+
defuses = IdDict{Int, Tuple{BitSet, SSADefUse}}()
650+
end
651+
mid, defuse = get!(defuses, defidx, (BitSet(), SSADefUse()))
654652
push!(defuse.ccall_preserve_uses, idx)
655653
union!(mid, intermediaries)
656654
end
@@ -675,6 +673,9 @@ function sroa_pass!(ir::IRCode)
675673
else
676674
continue
677675
end
676+
677+
# analyze this `getfield` / `setfield!` call
678+
678679
field = try_compute_field_stmt(compact, stmt)
679680
field === nothing && continue
680681

@@ -689,19 +690,23 @@ function sroa_pass!(ir::IRCode)
689690
continue
690691
end
691692

692-
def, typeconstraint = stmt.args[2], struct_typ
693+
val = stmt.args[2]
693694

695+
# analyze this mutable struct here for the later pass
694696
if ismutabletype(struct_typ)
695-
isa(def, SSAValue) || continue
696-
let intermediaries = IdSet{Int}()
697+
isa(val, SSAValue) || continue
698+
let intermediaries = BitSet()
697699
callback = function (@nospecialize(pi), @nospecialize(ssa))
698700
push!(intermediaries, ssa.id)
699701
return false
700702
end
701-
def = simple_walk(compact, def, callback)
703+
def = simple_walk(compact, val, callback)
702704
# Mutable stuff here
703705
isa(def, SSAValue) || continue
704-
mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
706+
if defuses === nothing
707+
defuses = IdDict{Int, Tuple{BitSet, SSADefUse}}()
708+
end
709+
mid, defuse = get!(defuses, def.id, (BitSet(), SSADefUse()))
705710
if is_setfield
706711
push!(defuse.defs, idx)
707712
else
@@ -711,26 +716,21 @@ function sroa_pass!(ir::IRCode)
711716
end
712717
continue
713718
elseif is_setfield
714-
continue
719+
continue # invalid `setfield!` call, but just ignore here
715720
end
716721

717722
# perform SROA on immutable structs here on
718723

719-
if isa(def, Union{OldSSAValue, SSAValue})
720-
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
721-
end
722-
723-
visited_phinodes = AnySSAValue[]
724-
leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)
725-
726-
isempty(leaves) && continue
727-
728724
field = try_compute_fieldidx(struct_typ, field)
729725
field === nothing && continue
730726

731-
r = lift_leaves(compact, result_t, field, leaves)
732-
r === nothing && continue
733-
lifted_leaves, any_undef = r
727+
leaves, visited_phinodes = collect_leaves(compact, val, struct_typ)
728+
isempty(leaves) && continue
729+
730+
result_t = compact_exprtype(compact, SSAValue(idx))
731+
lifted_result = lift_leaves(compact, result_t, field, leaves)
732+
lifted_result === nothing && continue
733+
lifted_leaves, any_undef = lifted_result
734734

735735
if any_undef
736736
result_t = make_MaybeUndef(result_t)
@@ -750,28 +750,32 @@ function sroa_pass!(ir::IRCode)
750750
@assert val !== nothing
751751
end
752752

753-
# global assertion_counter
754-
# assertion_counter::Int += 1
755-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756-
# continue
757753
compact[idx] = val === nothing ? nothing : val.x
758754
end
759755

760756
non_dce_finish!(compact)
761-
# Copy the use count, `simple_dce!` may modify it and for our predicate
762-
# below we need it consistent with the state of the IR here (after tracking
763-
# phi node arguments, but before dce).
764-
used_ssas = copy(compact.used_ssas)
765-
simple_dce!(compact)
766-
ir = complete(compact)
767-
768-
# Compute domtree, needed below, now that we have finished compacting the
769-
# IR. This needs to be after we iterate through the IR with
770-
# `IncrementalCompact` because removing dead blocks can invalidate the
771-
# domtree.
757+
if defuses !== nothing
758+
# now go through analyzed mutable structs and see which ones we can eliminate
759+
# NOTE copy the use count here, because `simple_dce!` may modify it and we need it
760+
# consistent with the state of the IR here (after tracking `PhiNode` arguments,
761+
# but before the DCE) for our predicate within `sroa_mutables!`
762+
used_ssas = copy(compact.used_ssas)
763+
simple_dce!(compact)
764+
ir = complete(compact)
765+
sroa_mutables!(ir, defuses, used_ssas)
766+
return ir
767+
else
768+
simple_dce!(compact)
769+
return complete(compact)
770+
end
771+
end
772+
773+
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{BitSet, SSADefUse}}, used_ssas::Vector{Int})
774+
# Compute domtree, needed below, now that we have finished compacting the IR.
775+
# This needs to be after we iterate through the IR with `IncrementalCompact`
776+
# because removing dead blocks can invalidate the domtree.
772777
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
773778

774-
# Now go through any mutable structs and see which ones we can eliminate
775779
for (idx, (intermediaries, defuse)) in defuses
776780
intermediaries = collect(intermediaries)
777781
# Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +810,12 @@ function sroa_pass!(ir::IRCode)
806810
# it would have been deleted. That's fine, just ignore
807811
# the use in that case.
808812
stmt === nothing && continue
809-
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
813+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
810814
field === nothing && @goto skip
811815
push!(fielddefuse[field].uses, use)
812816
end
813817
for use in defuse.defs
814-
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
818+
field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ)
815819
field === nothing && @goto skip
816820
push!(fielddefuse[field].defs, use)
817821
end
@@ -846,8 +850,9 @@ function sroa_pass!(ir::IRCode)
846850
end
847851
end
848852
end
849-
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
850853
# Everything accounted for. Go field by field and perform idf
854+
preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
855+
IdDict{Int, Vector{Any}}((idx=>Any[] for idx in BitSet(defuse.ccall_preserve_uses)))
851856
for fidx in 1:ndefuse
852857
du = fielddefuse[fidx]
853858
ftyp = fieldtype(typ, fidx)
@@ -863,8 +868,10 @@ function sroa_pass!(ir::IRCode)
863868
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
864869
end
865870
if !isbitstype(ftyp)
866-
for (use, list) in preserve_uses
867-
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
871+
if preserve_uses !== nothing
872+
for (use, list) in preserve_uses
873+
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
874+
end
868875
end
869876
end
870877
for b in phiblocks
@@ -881,7 +888,7 @@ function sroa_pass!(ir::IRCode)
881888
ir[SSAValue(stmt)] = nothing
882889
end
883890
end
884-
isempty(defuse.ccall_preserve_uses) && continue
891+
preserve_uses === nothing && continue
885892
push!(intermediaries, newidx)
886893
# Insert the new preserves
887894
for (use, new_preserves) in preserve_uses
@@ -897,10 +904,7 @@ function sroa_pass!(ir::IRCode)
897904

898905
@label skip
899906
end
900-
901-
return ir
902907
end
903-
# assertion_counter = 0
904908

905909
"""
906910
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)

0 commit comments

Comments
 (0)