xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
1 //===----------- MultiBuffering.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements multi buffering transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
17 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dominance.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/ValueRange.h"
23 #include "mlir/Interfaces/LoopLikeInterface.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Debug.h"
26 
27 using namespace mlir;
28 
29 #define DEBUG_TYPE "memref-transforms"
30 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31 #define DBGSNL() (llvm::dbgs() << "\n")
32 
33 /// Return true if the op fully overwrite the given `buffer` value.
overrideBuffer(Operation * op,Value buffer)34 static bool overrideBuffer(Operation *op, Value buffer) {
35   auto copyOp = dyn_cast<memref::CopyOp>(op);
36   if (!copyOp)
37     return false;
38   return copyOp.getTarget() == buffer;
39 }
40 
41 /// Replace the uses of `oldOp` with the given `val` and for subview uses
42 /// propagate the type change. Changing the memref type may require propagating
43 /// it through subview ops so we cannot just do a replaceAllUse but need to
44 /// propagate the type change and erase old subview ops.
replaceUsesAndPropagateType(RewriterBase & rewriter,Operation * oldOp,Value val)45 static void replaceUsesAndPropagateType(RewriterBase &rewriter,
46                                         Operation *oldOp, Value val) {
47   SmallVector<Operation *> opsToDelete;
48   SmallVector<OpOperand *> operandsToReplace;
49 
50   // Save the operand to replace / delete later (avoid iterator invalidation).
51   // TODO: can we use an early_inc iterator?
52   for (OpOperand &use : oldOp->getUses()) {
53     // Non-subview ops will be replaced by `val`.
54     auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
55     if (!subviewUse) {
56       operandsToReplace.push_back(&use);
57       continue;
58     }
59 
60     // `subview(old_op)` is replaced by a new `subview(val)`.
61     OpBuilder::InsertionGuard g(rewriter);
62     rewriter.setInsertionPoint(subviewUse);
63     Type newType = memref::SubViewOp::inferRankReducedResultType(
64         subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
65         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66         subviewUse.getStaticStrides());
67     Value newSubview = rewriter.create<memref::SubViewOp>(
68         subviewUse->getLoc(), cast<MemRefType>(newType), val,
69         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70         subviewUse.getMixedStrides());
71 
72     // Ouch recursion ... is this really necessary?
73     replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
74 
75     opsToDelete.push_back(use.getOwner());
76   }
77 
78   // Perform late replacement.
79   // TODO: can we use an early_inc iterator?
80   for (OpOperand *operand : operandsToReplace) {
81     Operation *op = operand->getOwner();
82     rewriter.startOpModification(op);
83     operand->set(val);
84     rewriter.finalizeOpModification(op);
85   }
86 
87   // Perform late op erasure.
88   // TODO: can we use an early_inc iterator?
89   for (Operation *op : opsToDelete)
90     rewriter.eraseOp(op);
91 }
92 
93 // Transformation to do multi-buffering/array expansion to remove dependencies
94 // on the temporary allocation between consecutive loop iterations.
95 // Returns success if the transformation happened and failure otherwise.
96 // This is not a pattern as it requires propagating the new memref type to its
97 // uses and requires updating subview ops.
98 FailureOr<memref::AllocOp>
multiBuffer(RewriterBase & rewriter,memref::AllocOp allocOp,unsigned multiBufferingFactor,bool skipOverrideAnalysis)99 mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
100                           unsigned multiBufferingFactor,
101                           bool skipOverrideAnalysis) {
102   LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
103   DominanceInfo dom(allocOp->getParentOp());
104   LoopLikeOpInterface candidateLoop;
105   for (Operation *user : allocOp->getUsers()) {
106     auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
107     if (!parentLoop) {
108       if (isa<memref::DeallocOp>(user)) {
109         // Allow dealloc outside of any loop.
110         // TODO: The whole precondition function here is very brittle and will
111         // need to rethought an isolated into a cleaner analysis.
112         continue;
113       }
114       LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
115       LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
116       return failure();
117     }
118     if (!skipOverrideAnalysis) {
119       /// Make sure there is no loop-carried dependency on the allocation.
120       if (!overrideBuffer(user, allocOp.getResult())) {
121         LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
122         continue;
123       }
124       // If this user doesn't dominate all the other users keep looking.
125       if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
126             return !dom.dominates(user, otherUser);
127           })) {
128         LLVM_DEBUG(
129             DBGS() << "--Skip user: does not dominate all other users\n");
130         continue;
131       }
132     } else {
133       if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
134             return !isa<memref::DeallocOp>(otherUser) &&
135                    !parentLoop->isProperAncestor(otherUser);
136           })) {
137         LLVM_DEBUG(
138             DBGS()
139             << "--Skip user: not all other users are in the parent loop\n");
140         continue;
141       }
142     }
143     candidateLoop = parentLoop;
144     break;
145   }
146 
147   if (!candidateLoop) {
148     LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
149     return failure();
150   }
151 
152   std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
153   std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
154   std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
155   if (!inductionVar || !lowerBound || !singleStep ||
156       !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
157     LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
158     return failure();
159   }
160 
161   if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
162     LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
163     return failure();
164   }
165 
166   LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
167 
168   // 1. Construct the multi-buffered memref type.
169   ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
170   SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
171   llvm::append_range(multiBufferedShape, originalShape);
172   LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
173   MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
174                                 .setShape(multiBufferedShape)
175                                 .setLayout(MemRefLayoutAttrInterface());
176   LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
177 
178   // 2. Create the multi-buffered alloc.
179   Location loc = allocOp->getLoc();
180   OpBuilder::InsertionGuard g(rewriter);
181   rewriter.setInsertionPoint(allocOp);
182   auto mbAlloc = rewriter.create<memref::AllocOp>(
183       loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
184   LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
185 
186   // 3. Within the loop, build the modular leading index (i.e. each loop
187   // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
188   rewriter.setInsertionPointToStart(
189       &candidateLoop.getLoopRegions().front()->front());
190   Value ivVal = *inductionVar;
191   Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
192   Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
193   AffineExpr iv, lb, step;
194   bindDims(rewriter.getContext(), iv, lb, step);
195   Value bufferIndex = affine::makeComposedAffineApply(
196       rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
197       {ivVal, lbVal, stepVal});
198   LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
199 
200   // 4. Build the subview accessing the particular slice, taking modular
201   // rotation into account.
202   int64_t mbMemRefTypeRank = mbMemRefType.getRank();
203   IntegerAttr zero = rewriter.getIndexAttr(0);
204   IntegerAttr one = rewriter.getIndexAttr(1);
205   SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
206   SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
207   SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
208   // Offset is [bufferIndex, 0 ... 0 ].
209   offsets.front() = bufferIndex;
210   // Sizes is [1, original_size_0 ... original_size_n ].
211   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213   // Strides is [1, 1 ... 1 ].
214   auto dstMemref =
215       cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216           originalShape, mbMemRefType, offsets, sizes, strides));
217   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218                                                      offsets, sizes, strides);
219   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
220 
221   // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
222   // handle dealloc uses separately..
223   for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
224     auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
225     if (!deallocOp)
226       continue;
227     OpBuilder::InsertionGuard g(rewriter);
228     rewriter.setInsertionPoint(deallocOp);
229     auto newDeallocOp =
230         rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
231     (void)newDeallocOp;
232     LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
233     rewriter.eraseOp(deallocOp);
234   }
235 
236   // 6. RAUW with the particular slice, taking modular rotation into account.
237   replaceUsesAndPropagateType(rewriter, allocOp, subview);
238 
239   // 7. Finally, erase the old allocOp.
240   rewriter.eraseOp(allocOp);
241 
242   return mbAlloc;
243 }
244 
245 FailureOr<memref::AllocOp>
multiBuffer(memref::AllocOp allocOp,unsigned multiBufferingFactor,bool skipOverrideAnalysis)246 mlir::memref::multiBuffer(memref::AllocOp allocOp,
247                           unsigned multiBufferingFactor,
248                           bool skipOverrideAnalysis) {
249   IRRewriter rewriter(allocOp->getContext());
250   return multiBuffer(rewriter, allocOp, multiBufferingFactor,
251                      skipOverrideAnalysis);
252 }
253