@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
23
23
24
24
compute_live_ins (cfg:: CFG , du:: SSADefUse ) = compute_live_ins (cfg, du. defs, du. uses)
25
25
26
- function try_compute_field_stmt (compact :: IncrementalCompact , stmt:: Expr )
26
+ function try_compute_field_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr )
27
27
field = stmt. args[3 ]
28
28
# fields are usually literals, handle them manually
29
29
if isa (field, QuoteNode)
30
30
field = field. value
31
31
elseif isa (field, Int)
32
32
# try to resolve other constants, e.g. global reference
33
33
else
34
- field = compact_exprtype (compact , field)
34
+ field = isa (ir, IncrementalCompact) ? compact_exprtype (ir , field) : argextype (field, ir )
35
35
if isa (field, Const)
36
36
field = field. val
37
37
else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
42
42
return field
43
43
end
44
44
45
- function try_compute_fieldidx_stmt (compact :: IncrementalCompact , stmt:: Expr , typ:: DataType )
46
- field = try_compute_field_stmt (compact , stmt)
45
+ function try_compute_fieldidx_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
46
+ field = try_compute_field_stmt (ir , stmt)
47
47
return try_compute_fieldidx (typ, field)
48
48
end
49
49
@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112
112
return def, stmtblock, curblock
113
113
end
114
114
115
+ function collect_leaves (compact:: IncrementalCompact , @nospecialize (val), @nospecialize (typeconstraint))
116
+ if isa (val, Union{OldSSAValue, SSAValue})
117
+ val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
118
+ end
119
+ return walk_to_defs (compact, val, typeconstraint)
120
+ end
121
+
115
122
function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
116
123
callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
117
124
while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152
159
end
153
160
154
161
function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
155
- @nospecialize (typeconstraint = types (compact)[defssa] ))
162
+ @nospecialize (typeconstraint))
156
163
callback = function (@nospecialize (pi ), @nospecialize (idx))
157
164
if isa (pi , PiNode)
158
165
typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164
171
end
165
172
166
173
"""
167
- walk_to_defs(compact, val, intermediaries )
174
+ walk_to_defs(compact, val, typeconstraint )
168
175
169
- Starting at `val` walk use-def chains to get all the leaves feeding into
170
- this val (pruning those leaves rules out by path conditions).
176
+ Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177
+ (pruning those leaves rules out by path conditions).
171
178
"""
172
- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint), visited_phinodes:: Vector{AnySSAValue} = AnySSAValue[])
173
- isa (defssa, AnySSAValue) || return Any[defssa]
179
+ function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
180
+ visited_phinodes = AnySSAValue[]
181
+ isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174
182
def = compact[defssa]
175
- isa (def, PhiNode) || return Any[defssa]
176
- # Step 2: Figure out what the struct is defined as
177
- # # Track definitions through PiNode/PhiNode
178
- found_def = false
179
- # # Track which PhiNodes, SSAValue intermediaries
180
- # # we forwarded through.
183
+ isa (def, PhiNode) || return Any[defssa], visited_phinodes
181
184
visited_constraints = IdDict {AnySSAValue, Any} ()
182
185
worklist_defs = AnySSAValue[]
183
186
worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239
242
push! (leaves, defssa)
240
243
end
241
244
end
242
- leaves
245
+ return leaves, visited_phinodes
243
246
end
244
247
245
- function process_immutable_preserve (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
248
+ function process_immutable_preserve! (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
246
249
for arg in (isexpr (def, :new ) ? def. args : def. args[2 : end ])
247
250
if ! isbitstype (widenconst (compact_exprtype (compact, arg)))
248
251
push! (new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449
452
return
450
453
end
451
454
452
- if isa (val, Union{OldSSAValue, SSAValue})
453
- val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
454
- end
455
-
456
- visited_phinodes = AnySSAValue[]
457
- leaves = walk_to_defs (compact, val, typeconstraint, visited_phinodes)
455
+ valtyp = widenconst (compact_exprtype (compact, val))
456
+ isa (valtyp, Union) || return # bail out if there won't be a good chance for lifting
458
457
458
+ leaves, visited_phinodes = collect_leaves (compact, val, valtyp)
459
459
length (leaves) ≤ 1 && return # bail out if we don't have multiple leaves
460
460
461
461
# Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476
476
visited_phinodes, cmp, lifting_cache, Bool,
477
477
lifted_leaves:: IdDict{Any, Union{Nothing,LiftedValue}} , val):: LiftedValue
478
478
479
- # global assertion_counter
480
- # assertion_counter::Int += 1
481
- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482
- # return
483
479
compact[idx] = lifted_val. x
484
480
end
485
481
@@ -596,17 +592,16 @@ a result of succeeding dead code elimination.
596
592
"""
597
593
function sroa_pass! (ir:: IRCode )
598
594
compact = IncrementalCompact (ir)
599
- defuses = IdDict {Int, Tuple{IdSet{Int}, SSADefUse}} ()
595
+ defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600
596
lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
601
597
for ((_, idx), stmt) in compact
598
+ # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602
599
isa (stmt, Expr) || continue
603
- result_t = compact_exprtype (compact, SSAValue (idx))
604
600
is_setfield = false
605
601
field_ordering = :unspecified
606
- # Step 1: Check whether the statement we're looking at is a getfield/setfield!
607
602
if is_known_call (stmt, setfield!, compact)
608
- is_setfield = true
609
603
4 <= length (stmt. args) <= 5 || continue
604
+ is_setfield = true
610
605
if length (stmt. args) == 5
611
606
field_ordering = compact_exprtype (compact, stmt. args[5 ])
612
607
end
@@ -624,7 +619,7 @@ function sroa_pass!(ir::IRCode)
624
619
old_preserves = stmt. args[(6 + nccallargs): end ]
625
620
for (pidx, preserved_arg) in enumerate (old_preserves)
626
621
isa (preserved_arg, SSAValue) || continue
627
- let intermediaries = IdSet {Int} ()
622
+ let intermediaries = BitSet ()
628
623
callback = function (@nospecialize (pi ), @nospecialize (ssa))
629
624
push! (intermediaries, ssa. id)
630
625
return false
@@ -634,7 +629,7 @@ function sroa_pass!(ir::IRCode)
634
629
defidx = def. id
635
630
def = compact[defidx]
636
631
if is_tuple_call (compact, def)
637
- process_immutable_preserve (new_preserves, compact, def)
632
+ process_immutable_preserve! (new_preserves, compact, def)
638
633
old_preserves[pidx] = nothing
639
634
continue
640
635
elseif isexpr (def, :new )
@@ -643,14 +638,17 @@ function sroa_pass!(ir::IRCode)
643
638
typ = unwrap_unionall (typ)
644
639
end
645
640
if typ isa DataType && ! ismutabletype (typ)
646
- process_immutable_preserve (new_preserves, compact, def)
641
+ process_immutable_preserve! (new_preserves, compact, def)
647
642
old_preserves[pidx] = nothing
648
643
continue
649
644
end
650
645
else
651
646
continue
652
647
end
653
- mid, defuse = get! (defuses, defidx, (IdSet {Int} (), SSADefUse ()))
648
+ if defuses === nothing
649
+ defuses = IdDict {Int, Tuple{BitSet, SSADefUse}} ()
650
+ end
651
+ mid, defuse = get! (defuses, defidx, (BitSet (), SSADefUse ()))
654
652
push! (defuse. ccall_preserve_uses, idx)
655
653
union! (mid, intermediaries)
656
654
end
@@ -675,6 +673,9 @@ function sroa_pass!(ir::IRCode)
675
673
else
676
674
continue
677
675
end
676
+
677
+ # analyze this `getfield` / `setfield!` call
678
+
678
679
field = try_compute_field_stmt (compact, stmt)
679
680
field === nothing && continue
680
681
@@ -689,19 +690,23 @@ function sroa_pass!(ir::IRCode)
689
690
continue
690
691
end
691
692
692
- def, typeconstraint = stmt. args[2 ], struct_typ
693
+ val = stmt. args[2 ]
693
694
695
+ # analyze this mutable struct here for the later pass
694
696
if ismutabletype (struct_typ)
695
- isa (def , SSAValue) || continue
696
- let intermediaries = IdSet {Int} ()
697
+ isa (val , SSAValue) || continue
698
+ let intermediaries = BitSet ()
697
699
callback = function (@nospecialize (pi ), @nospecialize (ssa))
698
700
push! (intermediaries, ssa. id)
699
701
return false
700
702
end
701
- def = simple_walk (compact, def , callback)
703
+ def = simple_walk (compact, val , callback)
702
704
# Mutable stuff here
703
705
isa (def, SSAValue) || continue
704
- mid, defuse = get! (defuses, def. id, (IdSet {Int} (), SSADefUse ()))
706
+ if defuses === nothing
707
+ defuses = IdDict {Int, Tuple{BitSet, SSADefUse}} ()
708
+ end
709
+ mid, defuse = get! (defuses, def. id, (BitSet (), SSADefUse ()))
705
710
if is_setfield
706
711
push! (defuse. defs, idx)
707
712
else
@@ -711,26 +716,21 @@ function sroa_pass!(ir::IRCode)
711
716
end
712
717
continue
713
718
elseif is_setfield
714
- continue
719
+ continue # invalid `setfield!` call, but just ignore here
715
720
end
716
721
717
722
# perform SROA on immutable structs here on
718
723
719
- if isa (def, Union{OldSSAValue, SSAValue})
720
- def, typeconstraint = simple_walk_constraint (compact, def, typeconstraint)
721
- end
722
-
723
- visited_phinodes = AnySSAValue[]
724
- leaves = walk_to_defs (compact, def, typeconstraint, visited_phinodes)
725
-
726
- isempty (leaves) && continue
727
-
728
724
field = try_compute_fieldidx (struct_typ, field)
729
725
field === nothing && continue
730
726
731
- r = lift_leaves (compact, result_t, field, leaves)
732
- r === nothing && continue
733
- lifted_leaves, any_undef = r
727
+ leaves, visited_phinodes = collect_leaves (compact, val, struct_typ)
728
+ isempty (leaves) && continue
729
+
730
+ result_t = compact_exprtype (compact, SSAValue (idx))
731
+ lifted_result = lift_leaves (compact, result_t, field, leaves)
732
+ lifted_result === nothing && continue
733
+ lifted_leaves, any_undef = lifted_result
734
734
735
735
if any_undef
736
736
result_t = make_MaybeUndef (result_t)
@@ -750,28 +750,32 @@ function sroa_pass!(ir::IRCode)
750
750
@assert val != = nothing
751
751
end
752
752
753
- # global assertion_counter
754
- # assertion_counter::Int += 1
755
- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756
- # continue
757
753
compact[idx] = val === nothing ? nothing : val. x
758
754
end
759
755
760
756
non_dce_finish! (compact)
761
- # Copy the use count, `simple_dce!` may modify it and for our predicate
762
- # below we need it consistent with the state of the IR here (after tracking
763
- # phi node arguments, but before dce).
764
- used_ssas = copy (compact. used_ssas)
765
- simple_dce! (compact)
766
- ir = complete (compact)
767
-
768
- # Compute domtree, needed below, now that we have finished compacting the
769
- # IR. This needs to be after we iterate through the IR with
770
- # `IncrementalCompact` because removing dead blocks can invalidate the
771
- # domtree.
757
+ if defuses != = nothing
758
+ # now go through analyzed mutable structs and see which ones we can eliminate
759
+ # NOTE copy the use count here, because `simple_dce!` may modify it and we need it
760
+ # consistent with the state of the IR here (after tracking `PhiNode` arguments,
761
+ # but before the DCE) for our predicate within `sroa_mutables!`
762
+ used_ssas = copy (compact. used_ssas)
763
+ simple_dce! (compact)
764
+ ir = complete (compact)
765
+ sroa_mutables! (ir, defuses, used_ssas)
766
+ return ir
767
+ else
768
+ simple_dce! (compact)
769
+ return complete (compact)
770
+ end
771
+ end
772
+
773
+ function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{BitSet, SSADefUse}} , used_ssas:: Vector{Int} )
774
+ # Compute domtree, needed below, now that we have finished compacting the IR.
775
+ # This needs to be after we iterate through the IR with `IncrementalCompact`
776
+ # because removing dead blocks can invalidate the domtree.
772
777
@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
773
778
774
- # Now go through any mutable structs and see which ones we can eliminate
775
779
for (idx, (intermediaries, defuse)) in defuses
776
780
intermediaries = collect (intermediaries)
777
781
# Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +810,12 @@ function sroa_pass!(ir::IRCode)
806
810
# it would have been deleted. That's fine, just ignore
807
811
# the use in that case.
808
812
stmt === nothing && continue
809
- field = try_compute_fieldidx_stmt (compact , stmt:: Expr , typ)
813
+ field = try_compute_fieldidx_stmt (ir , stmt:: Expr , typ)
810
814
field === nothing && @goto skip
811
815
push! (fielddefuse[field]. uses, use)
812
816
end
813
817
for use in defuse. defs
814
- field = try_compute_fieldidx_stmt (compact , ir[SSAValue (use)]:: Expr , typ)
818
+ field = try_compute_fieldidx_stmt (ir , ir[SSAValue (use)]:: Expr , typ)
815
819
field === nothing && @goto skip
816
820
push! (fielddefuse[field]. defs, use)
817
821
end
@@ -846,8 +850,9 @@ function sroa_pass!(ir::IRCode)
846
850
end
847
851
end
848
852
end
849
- preserve_uses = IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in IdSet {Int} (defuse. ccall_preserve_uses)))
850
853
# Everything accounted for. Go field by field and perform idf
854
+ preserve_uses = isempty (defuse. ccall_preserve_uses) ? nothing :
855
+ IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in BitSet (defuse. ccall_preserve_uses)))
851
856
for fidx in 1 : ndefuse
852
857
du = fielddefuse[fidx]
853
858
ftyp = fieldtype (typ, fidx)
@@ -863,8 +868,10 @@ function sroa_pass!(ir::IRCode)
863
868
ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
864
869
end
865
870
if ! isbitstype (ftyp)
866
- for (use, list) in preserve_uses
867
- push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
871
+ if preserve_uses != = nothing
872
+ for (use, list) in preserve_uses
873
+ push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
874
+ end
868
875
end
869
876
end
870
877
for b in phiblocks
@@ -881,7 +888,7 @@ function sroa_pass!(ir::IRCode)
881
888
ir[SSAValue (stmt)] = nothing
882
889
end
883
890
end
884
- isempty (defuse . ccall_preserve_uses) && continue
891
+ preserve_uses === nothing && continue
885
892
push! (intermediaries, newidx)
886
893
# Insert the new preserves
887
894
for (use, new_preserves) in preserve_uses
@@ -897,10 +904,7 @@ function sroa_pass!(ir::IRCode)
897
904
898
905
@label skip
899
906
end
900
-
901
- return ir
902
907
end
903
- # assertion_counter = 0
904
908
905
909
"""
906
910
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
0 commit comments