Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b8af9c6

Browse files
committedNov 26, 2021
optimizer: refactor SROA pass
- use `BitSet` instead of `IdSet{Int}` - reduce # of dynamic allocations - separate some computations into individual functions
1 parent df0080c commit b8af9c6

File tree

1 file changed

+87
-79
lines changed

1 file changed

+87
-79
lines changed
 

‎base/compiler/ssair/passes.jl

+87-79
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2323

2424
compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)
2525

26-
function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
26+
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
2727
field = stmt.args[3]
2828
# fields are usually literals, handle them manually
2929
if isa(field, QuoteNode)
3030
field = field.value
3131
elseif isa(field, Int)
3232
# try to resolve other constants, e.g. global reference
3333
else
34-
field = compact_exprtype(compact, field)
34+
field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir)
3535
if isa(field, Const)
3636
field = field.val
3737
else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
4242
return field
4343
end
4444

45-
function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
46-
field = try_compute_field_stmt(compact, stmt)
45+
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
46+
field = try_compute_field_stmt(ir, stmt)
4747
return try_compute_fieldidx(typ, field)
4848
end
4949

@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112112
return def, stmtblock, curblock
113113
end
114114

115+
function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint))
116+
if isa(val, Union{OldSSAValue, SSAValue})
117+
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
118+
end
119+
return walk_to_defs(compact, val, typeconstraint)
120+
end
121+
115122
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
116123
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
117124
while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152159
end
153160

154161
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
155-
@nospecialize(typeconstraint = types(compact)[defssa]))
162+
@nospecialize(typeconstraint))
156163
callback = function (@nospecialize(pi), @nospecialize(idx))
157164
if isa(pi, PiNode)
158165
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164171
end
165172

166173
"""
167-
walk_to_defs(compact, val, intermediaries)
174+
walk_to_defs(compact, val, typeconstraint)
168175
169-
Starting at `val` walk use-def chains to get all the leaves feeding into
170-
this val (pruning those leaves rules out by path conditions).
176+
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177+
(pruning those leaves rules out by path conditions).
171178
"""
172-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[])
173-
isa(defssa, AnySSAValue) || return Any[defssa]
179+
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
180+
visited_phinodes = AnySSAValue[]
181+
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174182
def = compact[defssa]
175-
isa(def, PhiNode) || return Any[defssa]
176-
# Step 2: Figure out what the struct is defined as
177-
## Track definitions through PiNode/PhiNode
178-
found_def = false
179-
## Track which PhiNodes, SSAValue intermediaries
180-
## we forwarded through.
183+
isa(def, PhiNode) || return Any[defssa], visited_phinodes
181184
visited_constraints = IdDict{AnySSAValue, Any}()
182185
worklist_defs = AnySSAValue[]
183186
worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239242
push!(leaves, defssa)
240243
end
241244
end
242-
leaves
245+
return leaves, visited_phinodes
243246
end
244247

245-
function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
248+
function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
246249
for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
247250
if !isbitstype(widenconst(compact_exprtype(compact, arg)))
248251
push!(new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449452
return
450453
end
451454

452-
if isa(val, Union{OldSSAValue, SSAValue})
453-
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
454-
end
455-
456-
visited_phinodes = AnySSAValue[]
457-
leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)
455+
valtyp = widenconst(compact_exprtype(compact, val))
456+
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting
458457

458+
leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
459459
length(leaves) 1 && return # bail out if we don't have multiple leaves
460460

461461
# Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476476
visited_phinodes, cmp, lifting_cache, Bool,
477477
lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue
478478

479-
# global assertion_counter
480-
# assertion_counter::Int += 1
481-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482-
# return
483479
compact[idx] = lifted_val.x
484480
end
485481

@@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
576572
return stmt_val # N.B. should never happen
577573
end
578574

575+
# NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
576+
# which can be very large sometimes, and analyzed program counters are often very sparse
577+
const SPCSet = IdSet{Int}
578+
579579
"""
580580
sroa_pass!(ir::IRCode) -> newir::IRCode
581581
@@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
596596
"""
597597
function sroa_pass!(ir::IRCode)
598598
compact = IncrementalCompact(ir)
599-
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
599+
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600600
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
601601
for ((_, idx), stmt) in compact
602+
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602603
isa(stmt, Expr) || continue
603-
result_t = compact_exprtype(compact, SSAValue(idx))
604604
is_setfield = false
605605
field_ordering = :unspecified
606-
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
607606
if is_known_call(stmt, setfield!, compact)
608-
is_setfield = true
609607
4 <= length(stmt.args) <= 5 || continue
608+
is_setfield = true
610609
if length(stmt.args) == 5
611610
field_ordering = compact_exprtype(compact, stmt.args[5])
612611
end
@@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
624623
old_preserves = stmt.args[(6+nccallargs):end]
625624
for (pidx, preserved_arg) in enumerate(old_preserves)
626625
isa(preserved_arg, SSAValue) || continue
627-
let intermediaries = IdSet{Int}()
626+
let intermediaries = SPCSet()
628627
callback = function (@nospecialize(pi), @nospecialize(ssa))
629628
push!(intermediaries, ssa.id)
630629
return false
@@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
634633
defidx = def.id
635634
def = compact[defidx]
636635
if is_tuple_call(compact, def)
637-
process_immutable_preserve(new_preserves, compact, def)
636+
process_immutable_preserve!(new_preserves, compact, def)
638637
old_preserves[pidx] = nothing
639638
continue
640639
elseif isexpr(def, :new)
@@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
643642
typ = unwrap_unionall(typ)
644643
end
645644
if typ isa DataType && !ismutabletype(typ)
646-
process_immutable_preserve(new_preserves, compact, def)
645+
process_immutable_preserve!(new_preserves, compact, def)
647646
old_preserves[pidx] = nothing
648647
continue
649648
end
650649
else
651650
continue
652651
end
653-
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
652+
if defuses === nothing
653+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
654+
end
655+
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
654656
push!(defuse.ccall_preserve_uses, idx)
655657
union!(mid, intermediaries)
656658
end
@@ -675,6 +677,9 @@ function sroa_pass!(ir::IRCode)
675677
else
676678
continue
677679
end
680+
681+
# analyze this `getfield` / `setfield!` call
682+
678683
field = try_compute_field_stmt(compact, stmt)
679684
field === nothing && continue
680685

@@ -689,19 +694,23 @@ function sroa_pass!(ir::IRCode)
689694
continue
690695
end
691696

692-
def, typeconstraint = stmt.args[2], struct_typ
697+
val = stmt.args[2]
693698

699+
# analyze this mutable struct here for the later pass
694700
if ismutabletype(struct_typ)
695-
isa(def, SSAValue) || continue
696-
let intermediaries = IdSet{Int}()
701+
isa(val, SSAValue) || continue
702+
let intermediaries = SPCSet()
697703
callback = function (@nospecialize(pi), @nospecialize(ssa))
698704
push!(intermediaries, ssa.id)
699705
return false
700706
end
701-
def = simple_walk(compact, def, callback)
707+
def = simple_walk(compact, val, callback)
702708
# Mutable stuff here
703709
isa(def, SSAValue) || continue
704-
mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
710+
if defuses === nothing
711+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
712+
end
713+
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
705714
if is_setfield
706715
push!(defuse.defs, idx)
707716
else
@@ -711,26 +720,21 @@ function sroa_pass!(ir::IRCode)
711720
end
712721
continue
713722
elseif is_setfield
714-
continue
723+
continue # invalid `setfield!` call, but just ignore here
715724
end
716725

717726
# perform SROA on immutable structs here on
718727

719-
if isa(def, Union{OldSSAValue, SSAValue})
720-
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
721-
end
722-
723-
visited_phinodes = AnySSAValue[]
724-
leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)
725-
726-
isempty(leaves) && continue
727-
728728
field = try_compute_fieldidx(struct_typ, field)
729729
field === nothing && continue
730730

731-
r = lift_leaves(compact, result_t, field, leaves)
732-
r === nothing && continue
733-
lifted_leaves, any_undef = r
731+
leaves, visited_phinodes = collect_leaves(compact, val, struct_typ)
732+
isempty(leaves) && continue
733+
734+
result_t = compact_exprtype(compact, SSAValue(idx))
735+
lifted_result = lift_leaves(compact, result_t, field, leaves)
736+
lifted_result === nothing && continue
737+
lifted_leaves, any_undef = lifted_result
734738

735739
if any_undef
736740
result_t = make_MaybeUndef(result_t)
@@ -750,28 +754,32 @@ function sroa_pass!(ir::IRCode)
750754
@assert val !== nothing
751755
end
752756

753-
# global assertion_counter
754-
# assertion_counter::Int += 1
755-
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756-
# continue
757757
compact[idx] = val === nothing ? nothing : val.x
758758
end
759759

760760
non_dce_finish!(compact)
761-
# Copy the use count, `simple_dce!` may modify it and for our predicate
762-
# below we need it consistent with the state of the IR here (after tracking
763-
# phi node arguments, but before dce).
764-
used_ssas = copy(compact.used_ssas)
765-
simple_dce!(compact)
766-
ir = complete(compact)
767-
768-
# Compute domtree, needed below, now that we have finished compacting the
769-
# IR. This needs to be after we iterate through the IR with
770-
# `IncrementalCompact` because removing dead blocks can invalidate the
771-
# domtree.
761+
if defuses !== nothing
762+
# now go through analyzed mutable structs and see which ones we can eliminate
763+
# NOTE copy the use count here, because `simple_dce!` may modify it and we need it
764+
# consistent with the state of the IR here (after tracking `PhiNode` arguments,
765+
# but before the DCE) for our predicate within `sroa_mutables!`
766+
used_ssas = copy(compact.used_ssas)
767+
simple_dce!(compact)
768+
ir = complete(compact)
769+
sroa_mutables!(ir, defuses, used_ssas)
770+
return ir
771+
else
772+
simple_dce!(compact)
773+
return complete(compact)
774+
end
775+
end
776+
777+
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
778+
# Compute domtree, needed below, now that we have finished compacting the IR.
779+
# This needs to be after we iterate through the IR with `IncrementalCompact`
780+
# because removing dead blocks can invalidate the domtree.
772781
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
773782

774-
# Now go through any mutable structs and see which ones we can eliminate
775783
for (idx, (intermediaries, defuse)) in defuses
776784
intermediaries = collect(intermediaries)
777785
# Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +814,12 @@ function sroa_pass!(ir::IRCode)
806814
# it would have been deleted. That's fine, just ignore
807815
# the use in that case.
808816
stmt === nothing && continue
809-
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
817+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
810818
field === nothing && @goto skip
811819
push!(fielddefuse[field].uses, use)
812820
end
813821
for use in defuse.defs
814-
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
822+
field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ)
815823
field === nothing && @goto skip
816824
push!(fielddefuse[field].defs, use)
817825
end
@@ -846,8 +854,9 @@ function sroa_pass!(ir::IRCode)
846854
end
847855
end
848856
end
849-
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
850857
# Everything accounted for. Go field by field and perform idf
858+
preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
859+
IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses)))
851860
for fidx in 1:ndefuse
852861
du = fielddefuse[fidx]
853862
ftyp = fieldtype(typ, fidx)
@@ -863,8 +872,10 @@ function sroa_pass!(ir::IRCode)
863872
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
864873
end
865874
if !isbitstype(ftyp)
866-
for (use, list) in preserve_uses
867-
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
875+
if preserve_uses !== nothing
876+
for (use, list) in preserve_uses
877+
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
878+
end
868879
end
869880
end
870881
for b in phiblocks
@@ -881,7 +892,7 @@ function sroa_pass!(ir::IRCode)
881892
ir[SSAValue(stmt)] = nothing
882893
end
883894
end
884-
isempty(defuse.ccall_preserve_uses) && continue
895+
preserve_uses === nothing && continue
885896
push!(intermediaries, newidx)
886897
# Insert the new preserves
887898
for (use, new_preserves) in preserve_uses
@@ -897,10 +908,7 @@ function sroa_pass!(ir::IRCode)
897908

898909
@label skip
899910
end
900-
901-
return ir
902911
end
903-
# assertion_counter = 0
904912

905913
"""
906914
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)

0 commit comments

Comments
 (0)
Please sign in to comment.