@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
23
23
24
24
compute_live_ins (cfg:: CFG , du:: SSADefUse ) = compute_live_ins (cfg, du. defs, du. uses)
25
25
26
- function try_compute_field_stmt (compact :: IncrementalCompact , stmt:: Expr )
26
+ function try_compute_field_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr )
27
27
field = stmt. args[3 ]
28
28
# fields are usually literals, handle them manually
29
29
if isa (field, QuoteNode)
30
30
field = field. value
31
31
elseif isa (field, Int)
32
32
# try to resolve other constants, e.g. global reference
33
33
else
34
- field = compact_exprtype (compact , field)
34
+ field = isa (ir, IncrementalCompact) ? compact_exprtype (ir , field) : argextype (field, ir )
35
35
if isa (field, Const)
36
36
field = field. val
37
37
else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
42
42
return field
43
43
end
44
44
45
- function try_compute_fieldidx_stmt (compact :: IncrementalCompact , stmt:: Expr , typ:: DataType )
46
- field = try_compute_field_stmt (compact , stmt)
45
+ function try_compute_fieldidx_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
46
+ field = try_compute_field_stmt (ir , stmt)
47
47
return try_compute_fieldidx (typ, field)
48
48
end
49
49
@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112
112
return def, stmtblock, curblock
113
113
end
114
114
115
+ function collect_leaves (compact:: IncrementalCompact , @nospecialize (val), @nospecialize (typeconstraint))
116
+ if isa (val, Union{OldSSAValue, SSAValue})
117
+ val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
118
+ end
119
+ return walk_to_defs (compact, val, typeconstraint)
120
+ end
121
+
115
122
function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
116
123
callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
117
124
while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152
159
end
153
160
154
161
function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
155
- @nospecialize (typeconstraint = types (compact)[defssa] ))
162
+ @nospecialize (typeconstraint))
156
163
callback = function (@nospecialize (pi ), @nospecialize (idx))
157
164
if isa (pi , PiNode)
158
165
typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164
171
end
165
172
166
173
"""
167
- walk_to_defs(compact, val, intermediaries )
174
+ walk_to_defs(compact, val, typeconstraint )
168
175
169
- Starting at `val` walk use-def chains to get all the leaves feeding into
170
- this val (pruning those leaves rules out by path conditions).
176
+ Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177
+ (pruning those leaves rules out by path conditions).
171
178
"""
172
- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint), visited_phinodes:: Vector{AnySSAValue} = AnySSAValue[])
173
- isa (defssa, AnySSAValue) || return Any[defssa]
179
+ function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
180
+ visited_phinodes = AnySSAValue[]
181
+ isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174
182
def = compact[defssa]
175
- isa (def, PhiNode) || return Any[defssa]
176
- # Step 2: Figure out what the struct is defined as
177
- # # Track definitions through PiNode/PhiNode
178
- found_def = false
179
- # # Track which PhiNodes, SSAValue intermediaries
180
- # # we forwarded through.
183
+ isa (def, PhiNode) || return Any[defssa], visited_phinodes
181
184
visited_constraints = IdDict {AnySSAValue, Any} ()
182
185
worklist_defs = AnySSAValue[]
183
186
worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239
242
push! (leaves, defssa)
240
243
end
241
244
end
242
- leaves
245
+ return leaves, visited_phinodes
243
246
end
244
247
245
- function process_immutable_preserve (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
248
+ function process_immutable_preserve! (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
246
249
for arg in (isexpr (def, :new ) ? def. args : def. args[2 : end ])
247
250
if ! isbitstype (widenconst (compact_exprtype (compact, arg)))
248
251
push! (new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449
452
return
450
453
end
451
454
452
- if isa (val, Union{OldSSAValue, SSAValue})
453
- val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
454
- end
455
-
456
- visited_phinodes = AnySSAValue[]
457
- leaves = walk_to_defs (compact, val, typeconstraint, visited_phinodes)
455
+ valtyp = widenconst (compact_exprtype (compact, val))
456
+ isa (valtyp, Union) || return # bail out if there won't be a good chance for lifting
458
457
458
+ leaves, visited_phinodes = collect_leaves (compact, val, valtyp)
459
459
length (leaves) ≤ 1 && return # bail out if we don't have multiple leaves
460
460
461
461
# Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476
476
visited_phinodes, cmp, lifting_cache, Bool,
477
477
lifted_leaves:: IdDict{Any, Union{Nothing,LiftedValue}} , val):: LiftedValue
478
478
479
- # global assertion_counter
480
- # assertion_counter::Int += 1
481
- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482
- # return
483
479
compact[idx] = lifted_val. x
484
480
end
485
481
@@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
576
572
return stmt_val # N.B. should never happen
577
573
end
578
574
575
+ # NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
576
+ # which can be very large sometimes, and analyzed program counters are often very sparse
577
+ const SPCSet = IdSet{Int}
578
+
579
579
"""
580
580
sroa_pass!(ir::IRCode) -> newir::IRCode
581
581
@@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
596
596
"""
597
597
function sroa_pass! (ir:: IRCode )
598
598
compact = IncrementalCompact (ir)
599
- defuses = IdDict {Int, Tuple{IdSet{Int}, SSADefUse}} ()
599
+ defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600
600
lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
601
601
for ((_, idx), stmt) in compact
602
+ # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602
603
isa (stmt, Expr) || continue
603
- result_t = compact_exprtype (compact, SSAValue (idx))
604
604
is_setfield = false
605
605
field_ordering = :unspecified
606
- # Step 1: Check whether the statement we're looking at is a getfield/setfield!
607
606
if is_known_call (stmt, setfield!, compact)
608
- is_setfield = true
609
607
4 <= length (stmt. args) <= 5 || continue
608
+ is_setfield = true
610
609
if length (stmt. args) == 5
611
610
field_ordering = compact_exprtype (compact, stmt. args[5 ])
612
611
end
@@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
624
623
old_preserves = stmt. args[(6 + nccallargs): end ]
625
624
for (pidx, preserved_arg) in enumerate (old_preserves)
626
625
isa (preserved_arg, SSAValue) || continue
627
- let intermediaries = IdSet {Int} ()
626
+ let intermediaries = SPCSet ()
628
627
callback = function (@nospecialize (pi ), @nospecialize (ssa))
629
628
push! (intermediaries, ssa. id)
630
629
return false
@@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
634
633
defidx = def. id
635
634
def = compact[defidx]
636
635
if is_tuple_call (compact, def)
637
- process_immutable_preserve (new_preserves, compact, def)
636
+ process_immutable_preserve! (new_preserves, compact, def)
638
637
old_preserves[pidx] = nothing
639
638
continue
640
639
elseif isexpr (def, :new )
@@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
643
642
typ = unwrap_unionall (typ)
644
643
end
645
644
if typ isa DataType && ! ismutabletype (typ)
646
- process_immutable_preserve (new_preserves, compact, def)
645
+ process_immutable_preserve! (new_preserves, compact, def)
647
646
old_preserves[pidx] = nothing
648
647
continue
649
648
end
650
649
else
651
650
continue
652
651
end
653
- mid, defuse = get! (defuses, defidx, (IdSet {Int} (), SSADefUse ()))
652
+ if defuses === nothing
653
+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
654
+ end
655
+ mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
654
656
push! (defuse. ccall_preserve_uses, idx)
655
657
union! (mid, intermediaries)
656
658
end
@@ -675,6 +677,9 @@ function sroa_pass!(ir::IRCode)
675
677
else
676
678
continue
677
679
end
680
+
681
+ # analyze this `getfield` / `setfield!` call
682
+
678
683
field = try_compute_field_stmt (compact, stmt)
679
684
field === nothing && continue
680
685
@@ -689,19 +694,23 @@ function sroa_pass!(ir::IRCode)
689
694
continue
690
695
end
691
696
692
- def, typeconstraint = stmt. args[2 ], struct_typ
697
+ val = stmt. args[2 ]
693
698
699
+ # analyze this mutable struct here for the later pass
694
700
if ismutabletype (struct_typ)
695
- isa (def , SSAValue) || continue
696
- let intermediaries = IdSet {Int} ()
701
+ isa (val , SSAValue) || continue
702
+ let intermediaries = SPCSet ()
697
703
callback = function (@nospecialize (pi ), @nospecialize (ssa))
698
704
push! (intermediaries, ssa. id)
699
705
return false
700
706
end
701
- def = simple_walk (compact, def , callback)
707
+ def = simple_walk (compact, val , callback)
702
708
# Mutable stuff here
703
709
isa (def, SSAValue) || continue
704
- mid, defuse = get! (defuses, def. id, (IdSet {Int} (), SSADefUse ()))
710
+ if defuses === nothing
711
+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
712
+ end
713
+ mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
705
714
if is_setfield
706
715
push! (defuse. defs, idx)
707
716
else
@@ -711,26 +720,21 @@ function sroa_pass!(ir::IRCode)
711
720
end
712
721
continue
713
722
elseif is_setfield
714
- continue
723
+ continue # invalid `setfield!` call, but just ignore here
715
724
end
716
725
717
726
# perform SROA on immutable structs here on
718
727
719
- if isa (def, Union{OldSSAValue, SSAValue})
720
- def, typeconstraint = simple_walk_constraint (compact, def, typeconstraint)
721
- end
722
-
723
- visited_phinodes = AnySSAValue[]
724
- leaves = walk_to_defs (compact, def, typeconstraint, visited_phinodes)
725
-
726
- isempty (leaves) && continue
727
-
728
728
field = try_compute_fieldidx (struct_typ, field)
729
729
field === nothing && continue
730
730
731
- r = lift_leaves (compact, result_t, field, leaves)
732
- r === nothing && continue
733
- lifted_leaves, any_undef = r
731
+ leaves, visited_phinodes = collect_leaves (compact, val, struct_typ)
732
+ isempty (leaves) && continue
733
+
734
+ result_t = compact_exprtype (compact, SSAValue (idx))
735
+ lifted_result = lift_leaves (compact, result_t, field, leaves)
736
+ lifted_result === nothing && continue
737
+ lifted_leaves, any_undef = lifted_result
734
738
735
739
if any_undef
736
740
result_t = make_MaybeUndef (result_t)
@@ -750,28 +754,32 @@ function sroa_pass!(ir::IRCode)
750
754
@assert val != = nothing
751
755
end
752
756
753
- # global assertion_counter
754
- # assertion_counter::Int += 1
755
- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756
- # continue
757
757
compact[idx] = val === nothing ? nothing : val. x
758
758
end
759
759
760
760
non_dce_finish! (compact)
761
- # Copy the use count, `simple_dce!` may modify it and for our predicate
762
- # below we need it consistent with the state of the IR here (after tracking
763
- # phi node arguments, but before dce).
764
- used_ssas = copy (compact. used_ssas)
765
- simple_dce! (compact)
766
- ir = complete (compact)
767
-
768
- # Compute domtree, needed below, now that we have finished compacting the
769
- # IR. This needs to be after we iterate through the IR with
770
- # `IncrementalCompact` because removing dead blocks can invalidate the
771
- # domtree.
761
+ if defuses != = nothing
762
+ # now go through analyzed mutable structs and see which ones we can eliminate
763
+ # NOTE copy the use count here, because `simple_dce!` may modify it and we need it
764
+ # consistent with the state of the IR here (after tracking `PhiNode` arguments,
765
+ # but before the DCE) for our predicate within `sroa_mutables!`
766
+ used_ssas = copy (compact. used_ssas)
767
+ simple_dce! (compact)
768
+ ir = complete (compact)
769
+ sroa_mutables! (ir, defuses, used_ssas)
770
+ return ir
771
+ else
772
+ simple_dce! (compact)
773
+ return complete (compact)
774
+ end
775
+ end
776
+
777
+ function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
778
+ # Compute domtree, needed below, now that we have finished compacting the IR.
779
+ # This needs to be after we iterate through the IR with `IncrementalCompact`
780
+ # because removing dead blocks can invalidate the domtree.
772
781
@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
773
782
774
- # Now go through any mutable structs and see which ones we can eliminate
775
783
for (idx, (intermediaries, defuse)) in defuses
776
784
intermediaries = collect (intermediaries)
777
785
# Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +814,12 @@ function sroa_pass!(ir::IRCode)
806
814
# it would have been deleted. That's fine, just ignore
807
815
# the use in that case.
808
816
stmt === nothing && continue
809
- field = try_compute_fieldidx_stmt (compact , stmt:: Expr , typ)
817
+ field = try_compute_fieldidx_stmt (ir , stmt:: Expr , typ)
810
818
field === nothing && @goto skip
811
819
push! (fielddefuse[field]. uses, use)
812
820
end
813
821
for use in defuse. defs
814
- field = try_compute_fieldidx_stmt (compact , ir[SSAValue (use)]:: Expr , typ)
822
+ field = try_compute_fieldidx_stmt (ir , ir[SSAValue (use)]:: Expr , typ)
815
823
field === nothing && @goto skip
816
824
push! (fielddefuse[field]. defs, use)
817
825
end
@@ -846,8 +854,9 @@ function sroa_pass!(ir::IRCode)
846
854
end
847
855
end
848
856
end
849
- preserve_uses = IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in IdSet {Int} (defuse. ccall_preserve_uses)))
850
857
# Everything accounted for. Go field by field and perform idf
858
+ preserve_uses = isempty (defuse. ccall_preserve_uses) ? nothing :
859
+ IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in SPCSet (defuse. ccall_preserve_uses)))
851
860
for fidx in 1 : ndefuse
852
861
du = fielddefuse[fidx]
853
862
ftyp = fieldtype (typ, fidx)
@@ -863,8 +872,10 @@ function sroa_pass!(ir::IRCode)
863
872
ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
864
873
end
865
874
if ! isbitstype (ftyp)
866
- for (use, list) in preserve_uses
867
- push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
875
+ if preserve_uses != = nothing
876
+ for (use, list) in preserve_uses
877
+ push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
878
+ end
868
879
end
869
880
end
870
881
for b in phiblocks
@@ -881,7 +892,7 @@ function sroa_pass!(ir::IRCode)
881
892
ir[SSAValue (stmt)] = nothing
882
893
end
883
894
end
884
- isempty (defuse . ccall_preserve_uses) && continue
895
+ preserve_uses === nothing && continue
885
896
push! (intermediaries, newidx)
886
897
# Insert the new preserves
887
898
for (use, new_preserves) in preserve_uses
@@ -897,10 +908,7 @@ function sroa_pass!(ir::IRCode)
897
908
898
909
@label skip
899
910
end
900
-
901
- return ir
902
911
end
903
- # assertion_counter = 0
904
912
905
913
"""
906
914
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
0 commit comments