xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
//===----------- MultiBuffer.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the multi-buffering transformation.
//
//===----------------------------------------------------------------------===//
12b1357fe6SThomas Raoux 
13b1357fe6SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
14c888a0ceSNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h"
15b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
16b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/Transforms/Passes.h"
17faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18c888a0ceSNicolas Vasilache #include "mlir/IR/AffineExpr.h"
19c888a0ceSNicolas Vasilache #include "mlir/IR/BuiltinAttributes.h"
20b1357fe6SThomas Raoux #include "mlir/IR/Dominance.h"
21c888a0ceSNicolas Vasilache #include "mlir/IR/PatternMatch.h"
22c888a0ceSNicolas Vasilache #include "mlir/IR/ValueRange.h"
23b1357fe6SThomas Raoux #include "mlir/Interfaces/LoopLikeInterface.h"
24c888a0ceSNicolas Vasilache #include "llvm/ADT/STLExtras.h"
25ec640382SNicolas Vasilache #include "llvm/Support/Debug.h"
26b1357fe6SThomas Raoux 
27b1357fe6SThomas Raoux using namespace mlir;
28b1357fe6SThomas Raoux 
29ec640382SNicolas Vasilache #define DEBUG_TYPE "memref-transforms"
30ec640382SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31ec640382SNicolas Vasilache #define DBGSNL() (llvm::dbgs() << "\n")
32ec640382SNicolas Vasilache 
33b1357fe6SThomas Raoux /// Return true if the op fully overwrite the given `buffer` value.
overrideBuffer(Operation * op,Value buffer)34b1357fe6SThomas Raoux static bool overrideBuffer(Operation *op, Value buffer) {
35b1357fe6SThomas Raoux   auto copyOp = dyn_cast<memref::CopyOp>(op);
36b1357fe6SThomas Raoux   if (!copyOp)
37b1357fe6SThomas Raoux     return false;
38136d746eSJacques Pienaar   return copyOp.getTarget() == buffer;
39b1357fe6SThomas Raoux }
40b1357fe6SThomas Raoux 
41b1357fe6SThomas Raoux /// Replace the uses of `oldOp` with the given `val` and for subview uses
42b1357fe6SThomas Raoux /// propagate the type change. Changing the memref type may require propagating
43b1357fe6SThomas Raoux /// it through subview ops so we cannot just do a replaceAllUse but need to
44b1357fe6SThomas Raoux /// propagate the type change and erase old subview ops.
replaceUsesAndPropagateType(RewriterBase & rewriter,Operation * oldOp,Value val)45c888a0ceSNicolas Vasilache static void replaceUsesAndPropagateType(RewriterBase &rewriter,
46c888a0ceSNicolas Vasilache                                         Operation *oldOp, Value val) {
47c888a0ceSNicolas Vasilache   SmallVector<Operation *> opsToDelete;
48b1357fe6SThomas Raoux   SmallVector<OpOperand *> operandsToReplace;
49c888a0ceSNicolas Vasilache 
50c888a0ceSNicolas Vasilache   // Save the operand to replace / delete later (avoid iterator invalidation).
51c888a0ceSNicolas Vasilache   // TODO: can we use an early_inc iterator?
52b1357fe6SThomas Raoux   for (OpOperand &use : oldOp->getUses()) {
53c888a0ceSNicolas Vasilache     // Non-subview ops will be replaced by `val`.
54b1357fe6SThomas Raoux     auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
55b1357fe6SThomas Raoux     if (!subviewUse) {
56b1357fe6SThomas Raoux       operandsToReplace.push_back(&use);
57b1357fe6SThomas Raoux       continue;
58b1357fe6SThomas Raoux     }
59c888a0ceSNicolas Vasilache 
60c888a0ceSNicolas Vasilache     // `subview(old_op)` is replaced by a new `subview(val)`.
61c888a0ceSNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
62c888a0ceSNicolas Vasilache     rewriter.setInsertionPoint(subviewUse);
63b1357fe6SThomas Raoux     Type newType = memref::SubViewOp::inferRankReducedResultType(
645550c821STres Popp         subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
65a9733b8aSLorenzo Chelini         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66a9733b8aSLorenzo Chelini         subviewUse.getStaticStrides());
67c888a0ceSNicolas Vasilache     Value newSubview = rewriter.create<memref::SubViewOp>(
685550c821STres Popp         subviewUse->getLoc(), cast<MemRefType>(newType), val,
69b1357fe6SThomas Raoux         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70b1357fe6SThomas Raoux         subviewUse.getMixedStrides());
71c888a0ceSNicolas Vasilache 
72c888a0ceSNicolas Vasilache     // Ouch recursion ... is this really necessary?
73c888a0ceSNicolas Vasilache     replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
74c888a0ceSNicolas Vasilache 
75c888a0ceSNicolas Vasilache     opsToDelete.push_back(use.getOwner());
76b1357fe6SThomas Raoux   }
77b1357fe6SThomas Raoux 
78c888a0ceSNicolas Vasilache   // Perform late replacement.
79c888a0ceSNicolas Vasilache   // TODO: can we use an early_inc iterator?
80c888a0ceSNicolas Vasilache   for (OpOperand *operand : operandsToReplace) {
81c888a0ceSNicolas Vasilache     Operation *op = operand->getOwner();
82*5fcf907bSMatthias Springer     rewriter.startOpModification(op);
83c888a0ceSNicolas Vasilache     operand->set(val);
84*5fcf907bSMatthias Springer     rewriter.finalizeOpModification(op);
85c888a0ceSNicolas Vasilache   }
86c888a0ceSNicolas Vasilache 
87c888a0ceSNicolas Vasilache   // Perform late op erasure.
88c888a0ceSNicolas Vasilache   // TODO: can we use an early_inc iterator?
89c888a0ceSNicolas Vasilache   for (Operation *op : opsToDelete)
90c888a0ceSNicolas Vasilache     rewriter.eraseOp(op);
91b1357fe6SThomas Raoux }
92b1357fe6SThomas Raoux 
93b1357fe6SThomas Raoux // Transformation to do multi-buffering/array expansion to remove dependencies
94b1357fe6SThomas Raoux // on the temporary allocation between consecutive loop iterations.
95b1357fe6SThomas Raoux // Returns success if the transformation happened and failure otherwise.
96b1357fe6SThomas Raoux // This is not a pattern as it requires propagating the new memref type to its
97b1357fe6SThomas Raoux // uses and requires updating subview ops.
98ddeb55abSThomas Raoux FailureOr<memref::AllocOp>
multiBuffer(RewriterBase & rewriter,memref::AllocOp allocOp,unsigned multiBufferingFactor,bool skipOverrideAnalysis)99c888a0ceSNicolas Vasilache mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
100c888a0ceSNicolas Vasilache                           unsigned multiBufferingFactor,
101ddeb55abSThomas Raoux                           bool skipOverrideAnalysis) {
102c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
103b1357fe6SThomas Raoux   DominanceInfo dom(allocOp->getParentOp());
104b1357fe6SThomas Raoux   LoopLikeOpInterface candidateLoop;
105b1357fe6SThomas Raoux   for (Operation *user : allocOp->getUsers()) {
106b1357fe6SThomas Raoux     auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
107ec640382SNicolas Vasilache     if (!parentLoop) {
108c888a0ceSNicolas Vasilache       if (isa<memref::DeallocOp>(user)) {
109c888a0ceSNicolas Vasilache         // Allow dealloc outside of any loop.
110c888a0ceSNicolas Vasilache         // TODO: The whole precondition function here is very brittle and will
111c888a0ceSNicolas Vasilache         // need to rethought an isolated into a cleaner analysis.
112c888a0ceSNicolas Vasilache         continue;
113c888a0ceSNicolas Vasilache       }
114c888a0ceSNicolas Vasilache       LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
115c888a0ceSNicolas Vasilache       LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
116b1357fe6SThomas Raoux       return failure();
117ec640382SNicolas Vasilache     }
118ddeb55abSThomas Raoux     if (!skipOverrideAnalysis) {
119ec640382SNicolas Vasilache       /// Make sure there is no loop-carried dependency on the allocation.
120ec640382SNicolas Vasilache       if (!overrideBuffer(user, allocOp.getResult())) {
121c888a0ceSNicolas Vasilache         LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
122b1357fe6SThomas Raoux         continue;
123ec640382SNicolas Vasilache       }
124b1357fe6SThomas Raoux       // If this user doesn't dominate all the other users keep looking.
125b1357fe6SThomas Raoux       if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
126b1357fe6SThomas Raoux             return !dom.dominates(user, otherUser);
127ec640382SNicolas Vasilache           })) {
128c888a0ceSNicolas Vasilache         LLVM_DEBUG(
129c888a0ceSNicolas Vasilache             DBGS() << "--Skip user: does not dominate all other users\n");
130b1357fe6SThomas Raoux         continue;
131ec640382SNicolas Vasilache       }
132ddeb55abSThomas Raoux     } else {
133ddeb55abSThomas Raoux       if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
134ddeb55abSThomas Raoux             return !isa<memref::DeallocOp>(otherUser) &&
135ddeb55abSThomas Raoux                    !parentLoop->isProperAncestor(otherUser);
136ddeb55abSThomas Raoux           })) {
137ddeb55abSThomas Raoux         LLVM_DEBUG(
138ddeb55abSThomas Raoux             DBGS()
139c888a0ceSNicolas Vasilache             << "--Skip user: not all other users are in the parent loop\n");
140ddeb55abSThomas Raoux         continue;
141ddeb55abSThomas Raoux       }
142ddeb55abSThomas Raoux     }
143b1357fe6SThomas Raoux     candidateLoop = parentLoop;
144b1357fe6SThomas Raoux     break;
145b1357fe6SThomas Raoux   }
146c888a0ceSNicolas Vasilache 
147ec640382SNicolas Vasilache   if (!candidateLoop) {
148ec640382SNicolas Vasilache     LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
149b1357fe6SThomas Raoux     return failure();
150ec640382SNicolas Vasilache   }
151c888a0ceSNicolas Vasilache 
15222426110SRamkumar Ramachandra   std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
15322426110SRamkumar Ramachandra   std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
15422426110SRamkumar Ramachandra   std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
1559b5ef2beSMatthias Springer   if (!inductionVar || !lowerBound || !singleStep ||
1569b5ef2beSMatthias Springer       !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
1579b5ef2beSMatthias Springer     LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
158b1357fe6SThomas Raoux     return failure();
159ec640382SNicolas Vasilache   }
160a8aeb651SKirsten Lee 
161ec640382SNicolas Vasilache   if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
162ec640382SNicolas Vasilache     LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
163a8aeb651SKirsten Lee     return failure();
164ec640382SNicolas Vasilache   }
165a8aeb651SKirsten Lee 
166c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
167c888a0ceSNicolas Vasilache 
168c888a0ceSNicolas Vasilache   // 1. Construct the multi-buffered memref type.
169c888a0ceSNicolas Vasilache   ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
170c888a0ceSNicolas Vasilache   SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
171c888a0ceSNicolas Vasilache   llvm::append_range(multiBufferedShape, originalShape);
172c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
173c888a0ceSNicolas Vasilache   MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
174c888a0ceSNicolas Vasilache                                 .setShape(multiBufferedShape)
175c888a0ceSNicolas Vasilache                                 .setLayout(MemRefLayoutAttrInterface());
176c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
177c888a0ceSNicolas Vasilache 
178c888a0ceSNicolas Vasilache   // 2. Create the multi-buffered alloc.
179b1357fe6SThomas Raoux   Location loc = allocOp->getLoc();
180c888a0ceSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
181c888a0ceSNicolas Vasilache   rewriter.setInsertionPoint(allocOp);
182c888a0ceSNicolas Vasilache   auto mbAlloc = rewriter.create<memref::AllocOp>(
183c888a0ceSNicolas Vasilache       loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
184c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
185a8aeb651SKirsten Lee 
186c888a0ceSNicolas Vasilache   // 3. Within the loop, build the modular leading index (i.e. each loop
187c888a0ceSNicolas Vasilache   // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
1889b5ef2beSMatthias Springer   rewriter.setInsertionPointToStart(
1899b5ef2beSMatthias Springer       &candidateLoop.getLoopRegions().front()->front());
190c888a0ceSNicolas Vasilache   Value ivVal = *inductionVar;
191c888a0ceSNicolas Vasilache   Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
192c888a0ceSNicolas Vasilache   Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
193c888a0ceSNicolas Vasilache   AffineExpr iv, lb, step;
194c888a0ceSNicolas Vasilache   bindDims(rewriter.getContext(), iv, lb, step);
1954c48f016SMatthias Springer   Value bufferIndex = affine::makeComposedAffineApply(
196c888a0ceSNicolas Vasilache       rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
197c888a0ceSNicolas Vasilache       {ivVal, lbVal, stepVal});
198c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
199a8aeb651SKirsten Lee 
200c888a0ceSNicolas Vasilache   // 4. Build the subview accessing the particular slice, taking modular
201c888a0ceSNicolas Vasilache   // rotation into account.
202c888a0ceSNicolas Vasilache   int64_t mbMemRefTypeRank = mbMemRefType.getRank();
203c888a0ceSNicolas Vasilache   IntegerAttr zero = rewriter.getIndexAttr(0);
204c888a0ceSNicolas Vasilache   IntegerAttr one = rewriter.getIndexAttr(1);
205c888a0ceSNicolas Vasilache   SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
206c888a0ceSNicolas Vasilache   SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
207c888a0ceSNicolas Vasilache   SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
208c888a0ceSNicolas Vasilache   // Offset is [bufferIndex, 0 ... 0 ].
209c888a0ceSNicolas Vasilache   offsets.front() = bufferIndex;
210c888a0ceSNicolas Vasilache   // Sizes is [1, original_size_0 ... original_size_n ].
211c888a0ceSNicolas Vasilache   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212c888a0ceSNicolas Vasilache     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213c888a0ceSNicolas Vasilache   // Strides is [1, 1 ... 1 ].
2145550c821STres Popp   auto dstMemref =
2155550c821STres Popp       cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
2165550c821STres Popp           originalShape, mbMemRefType, offsets, sizes, strides));
217c888a0ceSNicolas Vasilache   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218b1357fe6SThomas Raoux                                                      offsets, sizes, strides);
219c888a0ceSNicolas Vasilache   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
220c888a0ceSNicolas Vasilache 
221c888a0ceSNicolas Vasilache   // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
222c888a0ceSNicolas Vasilache   // handle dealloc uses separately..
223c888a0ceSNicolas Vasilache   for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
224c888a0ceSNicolas Vasilache     auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
225c888a0ceSNicolas Vasilache     if (!deallocOp)
226c888a0ceSNicolas Vasilache       continue;
227c888a0ceSNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
228c888a0ceSNicolas Vasilache     rewriter.setInsertionPoint(deallocOp);
229c888a0ceSNicolas Vasilache     auto newDeallocOp =
230c888a0ceSNicolas Vasilache         rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
231c888a0ceSNicolas Vasilache     (void)newDeallocOp;
232c888a0ceSNicolas Vasilache     LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
233c888a0ceSNicolas Vasilache     rewriter.eraseOp(deallocOp);
234c888a0ceSNicolas Vasilache   }
235c888a0ceSNicolas Vasilache 
236c888a0ceSNicolas Vasilache   // 6. RAUW with the particular slice, taking modular rotation into account.
237c888a0ceSNicolas Vasilache   replaceUsesAndPropagateType(rewriter, allocOp, subview);
238c888a0ceSNicolas Vasilache 
239c888a0ceSNicolas Vasilache   // 7. Finally, erase the old allocOp.
240c888a0ceSNicolas Vasilache   rewriter.eraseOp(allocOp);
241c888a0ceSNicolas Vasilache 
242c888a0ceSNicolas Vasilache   return mbAlloc;
243c888a0ceSNicolas Vasilache }
244c888a0ceSNicolas Vasilache 
245c888a0ceSNicolas Vasilache FailureOr<memref::AllocOp>
multiBuffer(memref::AllocOp allocOp,unsigned multiBufferingFactor,bool skipOverrideAnalysis)246c888a0ceSNicolas Vasilache mlir::memref::multiBuffer(memref::AllocOp allocOp,
247c888a0ceSNicolas Vasilache                           unsigned multiBufferingFactor,
248c888a0ceSNicolas Vasilache                           bool skipOverrideAnalysis) {
249c888a0ceSNicolas Vasilache   IRRewriter rewriter(allocOp->getContext());
250c888a0ceSNicolas Vasilache   return multiBuffer(rewriter, allocOp, multiBufferingFactor,
251c888a0ceSNicolas Vasilache                      skipOverrideAnalysis);
252b1357fe6SThomas Raoux }
253