Skip to content

Commit 4f1e195

Browse files
authored
Don't create reverse basic blocks (rBBs) for ForwardMode (rust-lang#213)
1 parent 6f78640 commit 4f1e195

File tree

2 files changed

+43
-40
lines changed

2 files changed

+43
-40
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

+40 -40
Original file line numberDiff line numberDiff line change
@@ -2083,25 +2083,23 @@ void createTerminator(DiffeGradientUtils *gutils,
20832083
DIFFE_TYPE retType) {
20842084

20852085
BasicBlock *nBB = cast<BasicBlock>(gutils->getNewFromOriginal(oBB));
2086-
BasicBlock *rBB = gutils->reverseBlocks[nBB].back();
2087-
assert(rBB);
2086+
assert(nBB);
20882087
IRBuilder<> nBuilder(nBB);
2089-
IRBuilder<> rBuilder(rBB);
2090-
rBuilder.setFastMathFlags(getFast());
2088+
nBuilder.setFastMathFlags(getFast());
20912089

20922090
if (ReturnInst *inst = dyn_cast_or_null<ReturnInst>(oBB->getTerminator())) {
20932091
SmallVector<Value *, 4> retargs;
20942092

20952093
if (retAlloca) {
2096-
auto result = rBuilder.CreateLoad(retAlloca, "retreload");
2094+
auto result = nBuilder.CreateLoad(retAlloca, "retreload");
20972095
// TODO reintroduce invariant load/group
20982096
// result->setMetadata(LLVMContext::MD_invariant_load,
20992097
// MDNode::get(retAlloca->getContext(), {}));
21002098
retargs.push_back(result);
21012099
}
21022100

21032101
if (dretAlloca) {
2104-
auto result = rBuilder.CreateLoad(dretAlloca, "dretreload");
2102+
auto result = nBuilder.CreateLoad(dretAlloca, "dretreload");
21052103
// TODO reintroduce invariant load/group
21062104
// result->setMetadata(LLVMContext::MD_invariant_load,
21072105
// MDNode::get(dretAlloca->getContext(), {}));
@@ -2110,7 +2108,6 @@ void createTerminator(DiffeGradientUtils *gutils,
21102108

21112109
if (gutils->newFunc->getReturnType()->isVoidTy()) {
21122110
assert(retargs.size() == 0);
2113-
rBuilder.CreateRetVoid();
21142111
return;
21152112
}
21162113

@@ -2119,16 +2116,17 @@ void createTerminator(DiffeGradientUtils *gutils,
21192116
if (gutils->isConstantValue(retVal)) {
21202117
retargs.push_back(ConstantFP::get(retVal->getType(), 0.0));
21212118
} else {
2122-
retargs.push_back(gutils->diffe(retVal, rBuilder));
2119+
retargs.push_back(gutils->diffe(retVal, nBuilder));
21232120
}
21242121

21252122
Value *toret = UndefValue::get(gutils->newFunc->getReturnType());
21262123
for (unsigned i = 0; i < retargs.size(); ++i) {
21272124
unsigned idx[] = {i};
2128-
toret = rBuilder.CreateInsertValue(toret, retargs[i], idx);
2125+
toret = nBuilder.CreateInsertValue(toret, retargs[i], idx);
21292126
}
2130-
rBuilder.CreateRet(toret);
21312127

2128+
gutils->erase(gutils->getNewFromOriginal(inst));
2129+
nBuilder.CreateRet(toret);
21322130
return;
21332131
}
21342132
}
@@ -2841,38 +2839,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
28412839
IRBuilder<>(&gutils->newFunc->getEntryBlock().front())
28422840
.CreateAlloca(todiff->getReturnType(), nullptr, "dtoreturn");
28432841
}
2844-
for (BasicBlock &oBB : *gutils->oldFunc) {
2845-
if (ReturnInst *orig = dyn_cast<ReturnInst>(oBB.getTerminator())) {
2846-
ReturnInst *op = cast<ReturnInst>(gutils->getNewFromOriginal(orig));
2847-
BasicBlock *BB = op->getParent();
2848-
IRBuilder<> rb(op);
2849-
rb.setFastMathFlags(getFast());
2850-
2851-
if (retAlloca) {
2852-
StoreInst *si = rb.CreateStore(
2853-
gutils->getNewFromOriginal(orig->getReturnValue()), retAlloca);
2854-
replacedReturns[orig] = si;
2855-
}
2856-
2857-
if (dretAlloca && !gutils->isConstantValue(orig->getReturnValue())) {
2858-
rb.CreateStore(gutils->invertPointerM(orig->getReturnValue(), rb),
2859-
dretAlloca);
2860-
}
2861-
2862-
if (retType == DIFFE_TYPE::OUT_DIFF &&
2863-
mode != DerivativeMode::ForwardMode) {
2864-
assert(orig->getReturnValue());
2865-
assert(differetval);
2866-
if (!gutils->isConstantValue(orig->getReturnValue())) {
2867-
IRBuilder<> reverseB(gutils->reverseBlocks[BB].back());
2868-
gutils->setDiffe(orig->getReturnValue(), differetval, reverseB);
2842+
if (mode == DerivativeMode::ReverseModeCombined ||
2843+
mode == DerivativeMode::ReverseModeGradient) {
2844+
for (BasicBlock &oBB : *gutils->oldFunc) {
2845+
if (ReturnInst *orig = dyn_cast<ReturnInst>(oBB.getTerminator())) {
2846+
ReturnInst *op = cast<ReturnInst>(gutils->getNewFromOriginal(orig));
2847+
BasicBlock *BB = op->getParent();
2848+
IRBuilder<> rb(op);
2849+
rb.setFastMathFlags(getFast());
2850+
2851+
if (retAlloca) {
2852+
StoreInst *si = rb.CreateStore(
2853+
gutils->getNewFromOriginal(orig->getReturnValue()), retAlloca);
2854+
replacedReturns[orig] = si;
28692855
}
2870-
} else if (mode != DerivativeMode::ForwardMode) {
2871-
assert(retAlloca == nullptr);
2872-
}
28732856

2874-
rb.CreateBr(gutils->reverseBlocks[BB].front());
2875-
gutils->erase(op);
2857+
if (dretAlloca && !gutils->isConstantValue(orig->getReturnValue())) {
2858+
rb.CreateStore(gutils->invertPointerM(orig->getReturnValue(), rb),
2859+
dretAlloca);
2860+
}
2861+
2862+
if (retType == DIFFE_TYPE::OUT_DIFF) {
2863+
assert(orig->getReturnValue());
2864+
assert(differetval);
2865+
if (!gutils->isConstantValue(orig->getReturnValue())) {
2866+
IRBuilder<> reverseB(gutils->reverseBlocks[BB].back());
2867+
gutils->setDiffe(orig->getReturnValue(), differetval, reverseB);
2868+
}
2869+
} else {
2870+
assert(retAlloca == nullptr);
2871+
}
2872+
2873+
rb.CreateBr(gutils->reverseBlocks[BB].front());
2874+
gutils->erase(op);
2875+
}
28762876
}
28772877
}
28782878

enzyme/Enzyme/GradientUtils.h

+3
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,9 @@ class DiffeGradientUtils : public GradientUtils {
12071207
constantvalues_, returnvals_, ActiveReturn, origToNew_,
12081208
mode) {
12091209
assert(reverseBlocks.size() == 0);
1210+
if (mode == DerivativeMode::ForwardMode) {
1211+
return;
1212+
}
12101213
for (BasicBlock *BB : originalBlocks) {
12111214
if (BB == inversionAllocs)
12121215
continue;

0 commit comments

Comments
 (0)