Skip to content

Commit d95e377

Browse files
committed
WIP cachefix (#181)
1 parent cf65343 commit d95e377

File tree

7 files changed

+548
-235
lines changed

7 files changed

+548
-235
lines changed

enzyme/Enzyme/AdjointGenerator.h

+5 -4
Original file line number | Diff line number | Diff line change
@@ -936,7 +936,7 @@ class AdjointGenerator
936936
BasicBlock *BB = Builder2.GetInsertBlock();
937937
if (original)
938938
BB = gutils->getNewFromOriginal(BB);
939-
BasicBlock *BB2 = gutils->reverseBlocks[BB];
939+
BasicBlock *BB2 = gutils->reverseBlocks[BB].back();
940940
if (!BB2) {
941941
llvm::errs() << "oldFunc: " << *gutils->oldFunc << "\n";
942942
llvm::errs() << "newFunc: " << *gutils->newFunc << "\n";
@@ -2495,7 +2495,6 @@ class AdjointGenerator
24952495
for (auto pair : geps) {
24962496
Value *op = pair.second;
24972497
Value *alloc = op;
2498-
llvm::errs() << "op: " << *op << "\n";
24992498
Value *replacement = gutils->unwrapM(op, BuilderZ, available,
25002499
UnwrapMode::LegalFullUnwrap);
25012500
tape =
@@ -4319,8 +4318,10 @@ class AdjointGenerator
43194318
eraseIfUnused(*orig, /*erase*/ false, /*check*/ false);
43204319
}
43214320

4322-
for (auto &a : *gutils->reverseBlocks[cast<BasicBlock>(
4323-
gutils->getNewFromOriginal(orig->getParent()))]) {
4321+
for (auto &a : *gutils
4322+
->reverseBlocks[cast<BasicBlock>(
4323+
gutils->getNewFromOriginal(orig->getParent()))]
4324+
.back()) {
43244325
mapp[&a] = &a;
43254326
}
43264327

enzyme/Enzyme/EnzymeLogic.cpp

+18 -14
Original file line number | Diff line number | Diff line change
@@ -1884,7 +1884,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
18841884
LoopContext loopContext;
18851885
BasicBlock *BB = cast<BasicBlock>(gutils->getNewFromOriginal(oBB));
18861886
bool inLoop = gutils->getContext(BB, loopContext);
1887-
BasicBlock *BB2 = gutils->reverseBlocks[BB];
1887+
BasicBlock *BB2 = gutils->reverseBlocks[BB].back();
18881888
assert(BB2);
18891889
IRBuilder<> Builder(BB2);
18901890
Builder.setFastMathFlags(getFast());
@@ -2031,7 +2031,8 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
20312031
handled = true;
20322032
if (!gutils->isConstantValue(oval)) {
20332033
BasicBlock *REB =
2034-
gutils->reverseBlocks[*loopContext.exitBlocks.begin()];
2034+
gutils->reverseBlocks[*loopContext.exitBlocks.begin()]
2035+
.back();
20352036
IRBuilder<> EB(REB);
20362037
if (REB->getTerminator())
20372038
EB.SetInsertPoint(REB->getTerminator());
@@ -2087,7 +2088,8 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
20872088
if (!gutils->isConstantValue(oval)) {
20882089

20892090
BasicBlock *REB =
2090-
gutils->reverseBlocks[*loopContext.exitBlocks.begin()];
2091+
gutils->reverseBlocks[*loopContext.exitBlocks.begin()]
2092+
.back();
20912093
IRBuilder<> EB(REB);
20922094
if (REB->getTerminator())
20932095
EB.SetInsertPoint(REB->getTerminator());
@@ -2181,7 +2183,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
21812183
pair.second->replaceAllUsesWith(replaceWith);
21822184
pair.second->eraseFromParent();
21832185
}
2184-
2186+
BB2 = gutils->reverseBlocks[BB].back();
21852187
Builder.SetInsertPoint(BB2);
21862188

21872189
Builder.CreateCondBr(
@@ -2211,6 +2213,7 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
22112213
targetToPreds[gutils->getReverseOrLatchMerge(pred, BB)].emplace_back(
22122214
std::make_pair(pred, BB));
22132215
}
2216+
BB2 = gutils->reverseBlocks[BB].back();
22142217
Builder.SetInsertPoint(BB2);
22152218
gutils->branchToCorrespondingTarget(BB, Builder, targetToPreds);
22162219
}
@@ -2580,14 +2583,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
25802583
assert(orig->getReturnValue());
25812584
assert(differetval);
25822585
if (!gutils->isConstantValue(orig->getReturnValue())) {
2583-
IRBuilder<> reverseB(gutils->reverseBlocks[BB]);
2586+
IRBuilder<> reverseB(gutils->reverseBlocks[BB].back());
25842587
gutils->setDiffe(orig->getReturnValue(), differetval, reverseB);
25852588
}
25862589
} else {
25872590
assert(retAlloca == nullptr);
25882591
}
25892592

2590-
rb.CreateBr(gutils->reverseBlocks[BB]);
2593+
rb.CreateBr(gutils->reverseBlocks[BB].front());
25912594
gutils->erase(op);
25922595
}
25932596
}
@@ -2698,9 +2701,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
26982701

26992702
auto Arch =
27002703
llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch();
2701-
int SharedAddrSpace = Arch == Triple::amdgcn
2702-
? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2703-
: 3;
2704+
unsigned int SharedAddrSpace =
2705+
Arch == Triple::amdgcn ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2706+
: 3;
27042707

27052708
if (topLevel) {
27062709
BasicBlock *sharedBlock = nullptr;
@@ -2758,8 +2761,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
27582761

27592762
IRBuilder<> instbuilder(OldEntryInsts, OldEntryInsts->begin());
27602763

2761-
auto BarrierInst = Arch == Triple::amdgcn ? Intrinsic::amdgcn_s_barrier
2762-
: Intrinsic::nvvm_barrier0;
2764+
auto BarrierInst = Arch == Triple::amdgcn
2765+
? (llvm::Intrinsic::ID)Intrinsic::amdgcn_s_barrier
2766+
: (llvm::Intrinsic::ID)Intrinsic::nvvm_barrier0;
27632767
cast<CallInst>(instbuilder.CreateCall(
27642768
Intrinsic::getDeclaration(gutils->newFunc->getParent(), BarrierInst),
27652769
{}));
@@ -2789,9 +2793,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
27892793
(IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable();
27902794
DeleteDeadBlock(gutils->inversionAllocs);
27912795
for (auto BBs : gutils->reverseBlocks) {
2792-
if (pred_begin(BBs.second) == pred_end(BBs.second)) {
2793-
(IRBuilder<>(BBs.second)).CreateUnreachable();
2794-
DeleteDeadBlock(BBs.second);
2796+
if (pred_begin(BBs.second.front()) == pred_end(BBs.second.front())) {
2797+
(IRBuilder<>(BBs.second.front())).CreateUnreachable();
2798+
DeleteDeadBlock(BBs.second.front());
27952799
}
27962800
}
27972801

0 commit comments

Comments
 (0)