@@ -2083,25 +2083,23 @@ void createTerminator(DiffeGradientUtils *gutils,
2083
2083
DIFFE_TYPE retType) {
2084
2084
2085
2085
BasicBlock *nBB = cast<BasicBlock>(gutils->getNewFromOriginal (oBB));
2086
- BasicBlock *rBB = gutils->reverseBlocks [nBB].back ();
2087
- assert (rBB);
2086
+ assert (nBB);
2088
2087
IRBuilder<> nBuilder (nBB);
2089
- IRBuilder<> rBuilder (rBB);
2090
- rBuilder.setFastMathFlags (getFast ());
2088
+ nBuilder.setFastMathFlags (getFast ());
2091
2089
2092
2090
if (ReturnInst *inst = dyn_cast_or_null<ReturnInst>(oBB->getTerminator ())) {
2093
2091
SmallVector<Value *, 4 > retargs;
2094
2092
2095
2093
if (retAlloca) {
2096
- auto result = rBuilder .CreateLoad (retAlloca, " retreload" );
2094
+ auto result = nBuilder .CreateLoad (retAlloca, " retreload" );
2097
2095
// TODO reintroduce invariant load/group
2098
2096
// result->setMetadata(LLVMContext::MD_invariant_load,
2099
2097
// MDNode::get(retAlloca->getContext(), {}));
2100
2098
retargs.push_back (result);
2101
2099
}
2102
2100
2103
2101
if (dretAlloca) {
2104
- auto result = rBuilder .CreateLoad (dretAlloca, " dretreload" );
2102
+ auto result = nBuilder .CreateLoad (dretAlloca, " dretreload" );
2105
2103
// TODO reintroduce invariant load/group
2106
2104
// result->setMetadata(LLVMContext::MD_invariant_load,
2107
2105
// MDNode::get(dretAlloca->getContext(), {}));
@@ -2110,7 +2108,6 @@ void createTerminator(DiffeGradientUtils *gutils,
2110
2108
2111
2109
if (gutils->newFunc ->getReturnType ()->isVoidTy ()) {
2112
2110
assert (retargs.size () == 0 );
2113
- rBuilder.CreateRetVoid ();
2114
2111
return ;
2115
2112
}
2116
2113
@@ -2119,16 +2116,17 @@ void createTerminator(DiffeGradientUtils *gutils,
2119
2116
if (gutils->isConstantValue (retVal)) {
2120
2117
retargs.push_back (ConstantFP::get (retVal->getType (), 0.0 ));
2121
2118
} else {
2122
- retargs.push_back (gutils->diffe (retVal, rBuilder ));
2119
+ retargs.push_back (gutils->diffe (retVal, nBuilder ));
2123
2120
}
2124
2121
2125
2122
Value *toret = UndefValue::get (gutils->newFunc ->getReturnType ());
2126
2123
for (unsigned i = 0 ; i < retargs.size (); ++i) {
2127
2124
unsigned idx[] = {i};
2128
- toret = rBuilder .CreateInsertValue (toret, retargs[i], idx);
2125
+ toret = nBuilder .CreateInsertValue (toret, retargs[i], idx);
2129
2126
}
2130
- rBuilder.CreateRet (toret);
2131
2127
2128
+ gutils->erase (gutils->getNewFromOriginal (inst));
2129
+ nBuilder.CreateRet (toret);
2132
2130
return ;
2133
2131
}
2134
2132
}
@@ -2841,38 +2839,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2841
2839
IRBuilder<>(&gutils->newFunc ->getEntryBlock ().front ())
2842
2840
.CreateAlloca (todiff->getReturnType (), nullptr , " dtoreturn" );
2843
2841
}
2844
- for (BasicBlock &oBB : *gutils->oldFunc ) {
2845
- if (ReturnInst *orig = dyn_cast<ReturnInst>(oBB.getTerminator ())) {
2846
- ReturnInst *op = cast<ReturnInst>(gutils->getNewFromOriginal (orig));
2847
- BasicBlock *BB = op->getParent ();
2848
- IRBuilder<> rb (op);
2849
- rb.setFastMathFlags (getFast ());
2850
-
2851
- if (retAlloca) {
2852
- StoreInst *si = rb.CreateStore (
2853
- gutils->getNewFromOriginal (orig->getReturnValue ()), retAlloca);
2854
- replacedReturns[orig] = si;
2855
- }
2856
-
2857
- if (dretAlloca && !gutils->isConstantValue (orig->getReturnValue ())) {
2858
- rb.CreateStore (gutils->invertPointerM (orig->getReturnValue (), rb),
2859
- dretAlloca);
2860
- }
2861
-
2862
- if (retType == DIFFE_TYPE::OUT_DIFF &&
2863
- mode != DerivativeMode::ForwardMode) {
2864
- assert (orig->getReturnValue ());
2865
- assert (differetval);
2866
- if (!gutils->isConstantValue (orig->getReturnValue ())) {
2867
- IRBuilder<> reverseB (gutils->reverseBlocks [BB].back ());
2868
- gutils->setDiffe (orig->getReturnValue (), differetval, reverseB);
2842
+ if (mode == DerivativeMode::ReverseModeCombined ||
2843
+ mode == DerivativeMode::ReverseModeGradient) {
2844
+ for (BasicBlock &oBB : *gutils->oldFunc ) {
2845
+ if (ReturnInst *orig = dyn_cast<ReturnInst>(oBB.getTerminator ())) {
2846
+ ReturnInst *op = cast<ReturnInst>(gutils->getNewFromOriginal (orig));
2847
+ BasicBlock *BB = op->getParent ();
2848
+ IRBuilder<> rb (op);
2849
+ rb.setFastMathFlags (getFast ());
2850
+
2851
+ if (retAlloca) {
2852
+ StoreInst *si = rb.CreateStore (
2853
+ gutils->getNewFromOriginal (orig->getReturnValue ()), retAlloca);
2854
+ replacedReturns[orig] = si;
2869
2855
}
2870
- } else if (mode != DerivativeMode::ForwardMode) {
2871
- assert (retAlloca == nullptr );
2872
- }
2873
2856
2874
- rb.CreateBr (gutils->reverseBlocks [BB].front ());
2875
- gutils->erase (op);
2857
+ if (dretAlloca && !gutils->isConstantValue (orig->getReturnValue ())) {
2858
+ rb.CreateStore (gutils->invertPointerM (orig->getReturnValue (), rb),
2859
+ dretAlloca);
2860
+ }
2861
+
2862
+ if (retType == DIFFE_TYPE::OUT_DIFF) {
2863
+ assert (orig->getReturnValue ());
2864
+ assert (differetval);
2865
+ if (!gutils->isConstantValue (orig->getReturnValue ())) {
2866
+ IRBuilder<> reverseB (gutils->reverseBlocks [BB].back ());
2867
+ gutils->setDiffe (orig->getReturnValue (), differetval, reverseB);
2868
+ }
2869
+ } else {
2870
+ assert (retAlloca == nullptr );
2871
+ }
2872
+
2873
+ rb.CreateBr (gutils->reverseBlocks [BB].front ());
2874
+ gutils->erase (op);
2875
+ }
2876
2876
}
2877
2877
}
2878
2878
0 commit comments