@@ -1884,7 +1884,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
1884
1884
LoopContext loopContext;
1885
1885
BasicBlock *BB = cast<BasicBlock>(gutils->getNewFromOriginal (oBB));
1886
1886
bool inLoop = gutils->getContext (BB, loopContext);
1887
- BasicBlock *BB2 = gutils->reverseBlocks [BB];
1887
+ BasicBlock *BB2 = gutils->reverseBlocks [BB]. back () ;
1888
1888
assert (BB2);
1889
1889
IRBuilder<> Builder (BB2);
1890
1890
Builder.setFastMathFlags (getFast ());
@@ -2031,7 +2031,8 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
2031
2031
handled = true ;
2032
2032
if (!gutils->isConstantValue (oval)) {
2033
2033
BasicBlock *REB =
2034
- gutils->reverseBlocks [*loopContext.exitBlocks .begin ()];
2034
+ gutils->reverseBlocks [*loopContext.exitBlocks .begin ()]
2035
+ .back ();
2035
2036
IRBuilder<> EB (REB);
2036
2037
if (REB->getTerminator ())
2037
2038
EB.SetInsertPoint (REB->getTerminator ());
@@ -2087,7 +2088,8 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
2087
2088
if (!gutils->isConstantValue (oval)) {
2088
2089
2089
2090
BasicBlock *REB =
2090
- gutils->reverseBlocks [*loopContext.exitBlocks .begin ()];
2091
+ gutils->reverseBlocks [*loopContext.exitBlocks .begin ()]
2092
+ .back ();
2091
2093
IRBuilder<> EB (REB);
2092
2094
if (REB->getTerminator ())
2093
2095
EB.SetInsertPoint (REB->getTerminator ());
@@ -2181,7 +2183,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
2181
2183
pair.second ->replaceAllUsesWith (replaceWith);
2182
2184
pair.second ->eraseFromParent ();
2183
2185
}
2184
-
2186
+ BB2 = gutils-> reverseBlocks [BB]. back ();
2185
2187
Builder.SetInsertPoint (BB2);
2186
2188
2187
2189
Builder.CreateCondBr (
@@ -2211,6 +2213,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
2211
2213
targetToPreds[gutils->getReverseOrLatchMerge (pred, BB)].emplace_back (
2212
2214
std::make_pair (pred, BB));
2213
2215
}
2216
+ BB2 = gutils->reverseBlocks [BB].back ();
2214
2217
Builder.SetInsertPoint (BB2);
2215
2218
gutils->branchToCorrespondingTarget (BB, Builder, targetToPreds);
2216
2219
}
@@ -2580,14 +2583,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2580
2583
assert (orig->getReturnValue ());
2581
2584
assert (differetval);
2582
2585
if (!gutils->isConstantValue (orig->getReturnValue ())) {
2583
- IRBuilder<> reverseB (gutils->reverseBlocks [BB]);
2586
+ IRBuilder<> reverseB (gutils->reverseBlocks [BB]. back () );
2584
2587
gutils->setDiffe (orig->getReturnValue (), differetval, reverseB);
2585
2588
}
2586
2589
} else {
2587
2590
assert (retAlloca == nullptr );
2588
2591
}
2589
2592
2590
- rb.CreateBr (gutils->reverseBlocks [BB]);
2593
+ rb.CreateBr (gutils->reverseBlocks [BB]. front () );
2591
2594
gutils->erase (op);
2592
2595
}
2593
2596
}
@@ -2698,9 +2701,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2698
2701
2699
2702
auto Arch =
2700
2703
llvm::Triple (gutils->newFunc ->getParent ()->getTargetTriple ()).getArch ();
2701
- int SharedAddrSpace = Arch == Triple::amdgcn
2702
- ? (int )AMDGPU::HSAMD::AddressSpaceQualifier::Local
2703
- : 3 ;
2704
+ unsigned int SharedAddrSpace =
2705
+ Arch == Triple::amdgcn ? (int )AMDGPU::HSAMD::AddressSpaceQualifier::Local
2706
+ : 3 ;
2704
2707
2705
2708
if (topLevel) {
2706
2709
BasicBlock *sharedBlock = nullptr ;
@@ -2758,8 +2761,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2758
2761
2759
2762
IRBuilder<> instbuilder (OldEntryInsts, OldEntryInsts->begin ());
2760
2763
2761
- auto BarrierInst = Arch == Triple::amdgcn ? Intrinsic::amdgcn_s_barrier
2762
- : Intrinsic::nvvm_barrier0;
2764
+ auto BarrierInst = Arch == Triple::amdgcn
2765
+ ? (llvm::Intrinsic::ID)Intrinsic::amdgcn_s_barrier
2766
+ : (llvm::Intrinsic::ID)Intrinsic::nvvm_barrier0;
2763
2767
cast<CallInst>(instbuilder.CreateCall (
2764
2768
Intrinsic::getDeclaration (gutils->newFunc ->getParent (), BarrierInst),
2765
2769
{}));
@@ -2789,9 +2793,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2789
2793
(IRBuilder<>(gutils->inversionAllocs )).CreateUnreachable ();
2790
2794
DeleteDeadBlock (gutils->inversionAllocs );
2791
2795
for (auto BBs : gutils->reverseBlocks ) {
2792
- if (pred_begin (BBs.second ) == pred_end (BBs.second )) {
2793
- (IRBuilder<>(BBs.second )).CreateUnreachable ();
2794
- DeleteDeadBlock (BBs.second );
2796
+ if (pred_begin (BBs.second . front ()) == pred_end (BBs.second . front () )) {
2797
+ (IRBuilder<>(BBs.second . front () )).CreateUnreachable ();
2798
+ DeleteDeadBlock (BBs.second . front () );
2795
2799
}
2796
2800
}
2797
2801
0 commit comments