xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp (revision fe9f1a215e0bdb5308f0dc5021bd02b173a45dbc)
1a70aa7bbSRiver Riddle //===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle //
9a70aa7bbSRiver Riddle // This file implements a pass to pipeline data transfers.
10a70aa7bbSRiver Riddle //
11a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
1367d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/Passes.h"
1467d0d7acSMichele Scuttari 
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
18a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
19a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
20a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h"
21abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
2267d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
23a70aa7bbSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
24a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h"
25a70aa7bbSRiver Riddle #include "mlir/Transforms/Passes.h"
26a70aa7bbSRiver Riddle #include "llvm/ADT/DenseMap.h"
27a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
28a70aa7bbSRiver Riddle 
2967d0d7acSMichele Scuttari namespace mlir {
304c48f016SMatthias Springer namespace affine {
3167d0d7acSMichele Scuttari #define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
3267d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/Passes.h.inc"
334c48f016SMatthias Springer } // namespace affine
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
36a70aa7bbSRiver Riddle #define DEBUG_TYPE "affine-pipeline-data-transfer"
37a70aa7bbSRiver Riddle 
38a70aa7bbSRiver Riddle using namespace mlir;
394c48f016SMatthias Springer using namespace mlir::affine;
40a70aa7bbSRiver Riddle 
41a70aa7bbSRiver Riddle namespace {
42a70aa7bbSRiver Riddle struct PipelineDataTransfer
434c48f016SMatthias Springer     : public affine::impl::AffinePipelineDataTransferBase<
444c48f016SMatthias Springer           PipelineDataTransfer> {
45a70aa7bbSRiver Riddle   void runOnOperation() override;
46a70aa7bbSRiver Riddle   void runOnAffineForOp(AffineForOp forOp);
47a70aa7bbSRiver Riddle 
48a70aa7bbSRiver Riddle   std::vector<AffineForOp> forOps;
49a70aa7bbSRiver Riddle };
50a70aa7bbSRiver Riddle 
51a70aa7bbSRiver Riddle } // namespace
52a70aa7bbSRiver Riddle 
53a70aa7bbSRiver Riddle /// Creates a pass to pipeline explicit movement of data across levels of the
54a70aa7bbSRiver Riddle /// memory hierarchy.
5558ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
564c48f016SMatthias Springer mlir::affine::createPipelineDataTransferPass() {
57a70aa7bbSRiver Riddle   return std::make_unique<PipelineDataTransfer>();
58a70aa7bbSRiver Riddle }
59a70aa7bbSRiver Riddle 
60a70aa7bbSRiver Riddle // Returns the position of the tag memref operand given a DMA operation.
61a70aa7bbSRiver Riddle // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
62d1ceb740SUday Bondhugula // added.
63a70aa7bbSRiver Riddle static unsigned getTagMemRefPos(Operation &dmaOp) {
64a70aa7bbSRiver Riddle   assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
65a70aa7bbSRiver Riddle   if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
66a70aa7bbSRiver Riddle     return dmaStartOp.getTagMemRefOperandIndex();
67a70aa7bbSRiver Riddle   }
68a70aa7bbSRiver Riddle   // First operand for a dma finish operation.
69a70aa7bbSRiver Riddle   return 0;
70a70aa7bbSRiver Riddle }
71a70aa7bbSRiver Riddle 
72a70aa7bbSRiver Riddle /// Doubles the buffer of the supplied memref on the specified 'affine.for'
73a70aa7bbSRiver Riddle /// operation by adding a leading dimension of size two to the memref.
74a70aa7bbSRiver Riddle /// Replaces all uses of the old memref by the new one while indexing the newly
75a70aa7bbSRiver Riddle /// added dimension by the loop IV of the specified 'affine.for' operation
76a70aa7bbSRiver Riddle /// modulo 2. Returns false if such a replacement cannot be performed.
77a70aa7bbSRiver Riddle static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
78a70aa7bbSRiver Riddle   auto *forBody = forOp.getBody();
79a70aa7bbSRiver Riddle   OpBuilder bInner(forBody, forBody->begin());
80a70aa7bbSRiver Riddle 
81a70aa7bbSRiver Riddle   // Doubles the shape with a leading dimension extent of 2.
82a70aa7bbSRiver Riddle   auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
83a70aa7bbSRiver Riddle     // Add the leading dimension in the shape for the double buffer.
84a70aa7bbSRiver Riddle     ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
85a70aa7bbSRiver Riddle     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
86a70aa7bbSRiver Riddle     newShape[0] = 2;
87a70aa7bbSRiver Riddle     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
88a70aa7bbSRiver Riddle     return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
89a70aa7bbSRiver Riddle   };
90a70aa7bbSRiver Riddle 
915550c821STres Popp   auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
92a70aa7bbSRiver Riddle   auto newMemRefType = doubleShape(oldMemRefType);
93a70aa7bbSRiver Riddle 
94a70aa7bbSRiver Riddle   // The double buffer is allocated right before 'forOp'.
95a70aa7bbSRiver Riddle   OpBuilder bOuter(forOp);
96a70aa7bbSRiver Riddle   // Put together alloc operands for any dynamic dimensions of the memref.
97a70aa7bbSRiver Riddle   SmallVector<Value, 4> allocOperands;
98a70aa7bbSRiver Riddle   for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
99399638f9SAliia Khasanova     if (dim.value() == ShapedType::kDynamic)
100a70aa7bbSRiver Riddle       allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
101a70aa7bbSRiver Riddle           forOp.getLoc(), oldMemRef, dim.index()));
102a70aa7bbSRiver Riddle   }
103a70aa7bbSRiver Riddle 
104a70aa7bbSRiver Riddle   // Create and place the alloc right before the 'affine.for' operation.
105a70aa7bbSRiver Riddle   Value newMemRef = bOuter.create<memref::AllocOp>(
106a70aa7bbSRiver Riddle       forOp.getLoc(), newMemRefType, allocOperands);
107a70aa7bbSRiver Riddle 
108a70aa7bbSRiver Riddle   // Create 'iv mod 2' value to index the leading dimension.
109a70aa7bbSRiver Riddle   auto d0 = bInner.getAffineDimExpr(0);
11023b794f7SMatthias Springer   int64_t step = forOp.getStepAsInt();
111a70aa7bbSRiver Riddle   auto modTwoMap =
112a70aa7bbSRiver Riddle       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
113a70aa7bbSRiver Riddle   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
114a70aa7bbSRiver Riddle                                                  forOp.getInductionVar());
115a70aa7bbSRiver Riddle 
116a70aa7bbSRiver Riddle   // replaceAllMemRefUsesWith will succeed unless the forOp body has
117a70aa7bbSRiver Riddle   // non-dereferencing uses of the memref (dealloc's are fine though).
118a70aa7bbSRiver Riddle   if (failed(replaceAllMemRefUsesWith(
119a70aa7bbSRiver Riddle           oldMemRef, newMemRef,
120a70aa7bbSRiver Riddle           /*extraIndices=*/{ivModTwoOp},
121a70aa7bbSRiver Riddle           /*indexRemap=*/AffineMap(),
122a70aa7bbSRiver Riddle           /*extraOperands=*/{},
123a70aa7bbSRiver Riddle           /*symbolOperands=*/{},
124a70aa7bbSRiver Riddle           /*domOpFilter=*/&*forOp.getBody()->begin()))) {
125a70aa7bbSRiver Riddle     LLVM_DEBUG(
126a70aa7bbSRiver Riddle         forOp.emitError("memref replacement for double buffering failed"));
127a70aa7bbSRiver Riddle     ivModTwoOp.erase();
128a70aa7bbSRiver Riddle     return false;
129a70aa7bbSRiver Riddle   }
130a70aa7bbSRiver Riddle   // Insert the dealloc op right after the for loop.
131a70aa7bbSRiver Riddle   bOuter.setInsertionPointAfter(forOp);
132a70aa7bbSRiver Riddle   bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
133a70aa7bbSRiver Riddle 
134a70aa7bbSRiver Riddle   return true;
135a70aa7bbSRiver Riddle }
136a70aa7bbSRiver Riddle 
137a70aa7bbSRiver Riddle /// Returns success if the IR is in a valid state.
138a70aa7bbSRiver Riddle void PipelineDataTransfer::runOnOperation() {
139a70aa7bbSRiver Riddle   // Do a post order walk so that inner loop DMAs are processed first. This is
140a70aa7bbSRiver Riddle   // necessary since 'affine.for' operations nested within would otherwise
141a70aa7bbSRiver Riddle   // become invalid (erased) when the outer loop is pipelined (the pipelined one
142a70aa7bbSRiver Riddle   // gets deleted and replaced by a prologue, a new steady-state loop and an
143a70aa7bbSRiver Riddle   // epilogue).
144a70aa7bbSRiver Riddle   forOps.clear();
145a70aa7bbSRiver Riddle   getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
146a70aa7bbSRiver Riddle   for (auto forOp : forOps)
147a70aa7bbSRiver Riddle     runOnAffineForOp(forOp);
148a70aa7bbSRiver Riddle }
149a70aa7bbSRiver Riddle 
150a70aa7bbSRiver Riddle // Check if tags of the dma start op and dma wait op match.
151a70aa7bbSRiver Riddle static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
152a70aa7bbSRiver Riddle   if (startOp.getTagMemRef() != waitOp.getTagMemRef())
153a70aa7bbSRiver Riddle     return false;
154a70aa7bbSRiver Riddle   auto startIndices = startOp.getTagIndices();
155a70aa7bbSRiver Riddle   auto waitIndices = waitOp.getTagIndices();
156a70aa7bbSRiver Riddle   // Both of these have the same number of indices since they correspond to the
157a70aa7bbSRiver Riddle   // same tag memref.
158a70aa7bbSRiver Riddle   for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
159a70aa7bbSRiver Riddle             e = startIndices.end();
160a70aa7bbSRiver Riddle        it != e; ++it, ++wIt) {
161a70aa7bbSRiver Riddle     // Keep it simple for now, just checking if indices match.
162a70aa7bbSRiver Riddle     // TODO: this would in general need to check if there is no
163a70aa7bbSRiver Riddle     // intervening write writing to the same tag location, i.e., memory last
164a70aa7bbSRiver Riddle     // write/data flow analysis. This is however sufficient/powerful enough for
165a70aa7bbSRiver Riddle     // now since the DMA generation pass or the input for it will always have
166a70aa7bbSRiver Riddle     // start/wait with matching tags (same SSA operand indices).
167a70aa7bbSRiver Riddle     if (*it != *wIt)
168a70aa7bbSRiver Riddle       return false;
169a70aa7bbSRiver Riddle   }
170a70aa7bbSRiver Riddle   return true;
171a70aa7bbSRiver Riddle }
172a70aa7bbSRiver Riddle 
173a70aa7bbSRiver Riddle // Identify matching DMA start/finish operations to overlap computation with.
174a70aa7bbSRiver Riddle static void findMatchingStartFinishInsts(
175a70aa7bbSRiver Riddle     AffineForOp forOp,
176a70aa7bbSRiver Riddle     SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
177a70aa7bbSRiver Riddle 
178a70aa7bbSRiver Riddle   // Collect outgoing DMA operations - needed to check for dependences below.
179a70aa7bbSRiver Riddle   SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
180a70aa7bbSRiver Riddle   for (auto &op : *forOp.getBody()) {
181a70aa7bbSRiver Riddle     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
182a70aa7bbSRiver Riddle     if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
183a70aa7bbSRiver Riddle       outgoingDmaOps.push_back(dmaStartOp);
184a70aa7bbSRiver Riddle   }
185a70aa7bbSRiver Riddle 
186a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
187a70aa7bbSRiver Riddle   for (auto &op : *forOp.getBody()) {
188a70aa7bbSRiver Riddle     // Collect DMA finish operations.
189a70aa7bbSRiver Riddle     if (isa<AffineDmaWaitOp>(op)) {
190a70aa7bbSRiver Riddle       dmaFinishInsts.push_back(&op);
191a70aa7bbSRiver Riddle       continue;
192a70aa7bbSRiver Riddle     }
193a70aa7bbSRiver Riddle     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
194a70aa7bbSRiver Riddle     if (!dmaStartOp)
195a70aa7bbSRiver Riddle       continue;
196a70aa7bbSRiver Riddle 
197a70aa7bbSRiver Riddle     // Only DMAs incoming into higher memory spaces are pipelined for now.
198a70aa7bbSRiver Riddle     // TODO: handle outgoing DMA pipelining.
199a70aa7bbSRiver Riddle     if (!dmaStartOp.isDestMemorySpaceFaster())
200a70aa7bbSRiver Riddle       continue;
201a70aa7bbSRiver Riddle 
202a70aa7bbSRiver Riddle     // Check for dependence with outgoing DMAs. Doing this conservatively.
203a70aa7bbSRiver Riddle     // TODO: use the dependence analysis to check for
204a70aa7bbSRiver Riddle     // dependences between an incoming and outgoing DMA in the same iteration.
205a70aa7bbSRiver Riddle     auto *it = outgoingDmaOps.begin();
206a70aa7bbSRiver Riddle     for (; it != outgoingDmaOps.end(); ++it) {
207a70aa7bbSRiver Riddle       if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
208a70aa7bbSRiver Riddle         break;
209a70aa7bbSRiver Riddle     }
210a70aa7bbSRiver Riddle     if (it != outgoingDmaOps.end())
211a70aa7bbSRiver Riddle       continue;
212a70aa7bbSRiver Riddle 
213a70aa7bbSRiver Riddle     // We only double buffer if the buffer is not live out of loop.
214a70aa7bbSRiver Riddle     auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
215a70aa7bbSRiver Riddle     bool escapingUses = false;
216a70aa7bbSRiver Riddle     for (auto *user : memref.getUsers()) {
217a70aa7bbSRiver Riddle       // We can double buffer regardless of dealloc's outside the loop.
218a70aa7bbSRiver Riddle       if (isa<memref::DeallocOp>(user))
219a70aa7bbSRiver Riddle         continue;
220a70aa7bbSRiver Riddle       if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
221a70aa7bbSRiver Riddle         LLVM_DEBUG(llvm::dbgs()
222a70aa7bbSRiver Riddle                        << "can't pipeline: buffer is live out of loop\n";);
223a70aa7bbSRiver Riddle         escapingUses = true;
224a70aa7bbSRiver Riddle         break;
225a70aa7bbSRiver Riddle       }
226a70aa7bbSRiver Riddle     }
227a70aa7bbSRiver Riddle     if (!escapingUses)
228a70aa7bbSRiver Riddle       dmaStartInsts.push_back(&op);
229a70aa7bbSRiver Riddle   }
230a70aa7bbSRiver Riddle 
231a70aa7bbSRiver Riddle   // For each start operation, we look for a matching finish operation.
232a70aa7bbSRiver Riddle   for (auto *dmaStartOp : dmaStartInsts) {
233a70aa7bbSRiver Riddle     for (auto *dmaFinishOp : dmaFinishInsts) {
234a70aa7bbSRiver Riddle       if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
235a70aa7bbSRiver Riddle                         cast<AffineDmaWaitOp>(dmaFinishOp))) {
236a70aa7bbSRiver Riddle         startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
237a70aa7bbSRiver Riddle         break;
238a70aa7bbSRiver Riddle       }
239a70aa7bbSRiver Riddle     }
240a70aa7bbSRiver Riddle   }
241a70aa7bbSRiver Riddle }
242a70aa7bbSRiver Riddle 
243a70aa7bbSRiver Riddle /// Overlap DMA transfers with computation in this loop. If successful,
244a70aa7bbSRiver Riddle /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
245a70aa7bbSRiver Riddle /// inserted right before where it was.
246a70aa7bbSRiver Riddle void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
247a70aa7bbSRiver Riddle   auto mayBeConstTripCount = getConstantTripCount(forOp);
248037f0995SKazu Hirata   if (!mayBeConstTripCount) {
249a70aa7bbSRiver Riddle     LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
250a70aa7bbSRiver Riddle     return;
251a70aa7bbSRiver Riddle   }
252a70aa7bbSRiver Riddle 
253a70aa7bbSRiver Riddle   SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
254a70aa7bbSRiver Riddle   findMatchingStartFinishInsts(forOp, startWaitPairs);
255a70aa7bbSRiver Riddle 
256a70aa7bbSRiver Riddle   if (startWaitPairs.empty()) {
257a70aa7bbSRiver Riddle     LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
258a70aa7bbSRiver Riddle     return;
259a70aa7bbSRiver Riddle   }
260a70aa7bbSRiver Riddle 
261a70aa7bbSRiver Riddle   // Double the buffers for the higher memory space memref's.
262a70aa7bbSRiver Riddle   // Identify memref's to replace by scanning through all DMA start
263a70aa7bbSRiver Riddle   // operations. A DMA start operation has two memref's - the one from the
264a70aa7bbSRiver Riddle   // higher level of memory hierarchy is the one to double buffer.
265a70aa7bbSRiver Riddle   // TODO: check whether double-buffering is even necessary.
266a70aa7bbSRiver Riddle   // TODO: make this work with different layouts: assuming here that
267a70aa7bbSRiver Riddle   // the dimension we are adding here for the double buffering is the outermost
268a70aa7bbSRiver Riddle   // dimension.
269a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
270a70aa7bbSRiver Riddle     auto *dmaStartOp = pair.first;
271a70aa7bbSRiver Riddle     Value oldMemRef = dmaStartOp->getOperand(
272a70aa7bbSRiver Riddle         cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
273a70aa7bbSRiver Riddle     if (!doubleBuffer(oldMemRef, forOp)) {
274a70aa7bbSRiver Riddle       // Normally, double buffering should not fail because we already checked
275a70aa7bbSRiver Riddle       // that there are no uses outside.
276a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs()
277a70aa7bbSRiver Riddle                      << "double buffering failed for" << dmaStartOp << "\n";);
278a70aa7bbSRiver Riddle       // IR still valid and semantically correct.
279a70aa7bbSRiver Riddle       return;
280a70aa7bbSRiver Riddle     }
281a70aa7bbSRiver Riddle     // If the old memref has no more uses, remove its 'dead' alloc if it was
282a70aa7bbSRiver Riddle     // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
283a70aa7bbSRiver Riddle     // operation could have been used on it if it was dynamically shaped in
284a70aa7bbSRiver Riddle     // order to create the double buffer above.)
285a70aa7bbSRiver Riddle     // '-canonicalize' does this in a more general way, but we'll anyway do the
286a70aa7bbSRiver Riddle     // simple/common case so that the output / test cases looks clear.
287a70aa7bbSRiver Riddle     if (auto *allocOp = oldMemRef.getDefiningOp()) {
288a70aa7bbSRiver Riddle       if (oldMemRef.use_empty()) {
289a70aa7bbSRiver Riddle         allocOp->erase();
290a70aa7bbSRiver Riddle       } else if (oldMemRef.hasOneUse()) {
291a70aa7bbSRiver Riddle         if (auto dealloc =
292a70aa7bbSRiver Riddle                 dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
293a70aa7bbSRiver Riddle           dealloc.erase();
294a70aa7bbSRiver Riddle           allocOp->erase();
295a70aa7bbSRiver Riddle         }
296a70aa7bbSRiver Riddle       }
297a70aa7bbSRiver Riddle     }
298a70aa7bbSRiver Riddle   }
299a70aa7bbSRiver Riddle 
300a70aa7bbSRiver Riddle   // Double the buffers for tag memrefs.
301a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
302a70aa7bbSRiver Riddle     auto *dmaFinishOp = pair.second;
303a70aa7bbSRiver Riddle     Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
304a70aa7bbSRiver Riddle     if (!doubleBuffer(oldTagMemRef, forOp)) {
305a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
306a70aa7bbSRiver Riddle       return;
307a70aa7bbSRiver Riddle     }
308a70aa7bbSRiver Riddle     // If the old tag has no uses or a single dealloc use, remove it.
309a70aa7bbSRiver Riddle     // (canonicalization handles more complex cases).
310a70aa7bbSRiver Riddle     if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
311a70aa7bbSRiver Riddle       if (oldTagMemRef.use_empty()) {
312a70aa7bbSRiver Riddle         tagAllocOp->erase();
313a70aa7bbSRiver Riddle       } else if (oldTagMemRef.hasOneUse()) {
314a70aa7bbSRiver Riddle         if (auto dealloc =
315a70aa7bbSRiver Riddle                 dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
316a70aa7bbSRiver Riddle           dealloc.erase();
317a70aa7bbSRiver Riddle           tagAllocOp->erase();
318a70aa7bbSRiver Riddle         }
319a70aa7bbSRiver Riddle       }
320a70aa7bbSRiver Riddle     }
321a70aa7bbSRiver Riddle   }
322a70aa7bbSRiver Riddle 
323a70aa7bbSRiver Riddle   // Double buffering would have invalidated all the old DMA start/wait insts.
324a70aa7bbSRiver Riddle   startWaitPairs.clear();
325a70aa7bbSRiver Riddle   findMatchingStartFinishInsts(forOp, startWaitPairs);
326a70aa7bbSRiver Riddle 
327a70aa7bbSRiver Riddle   // Store shift for operation for later lookup for AffineApplyOp's.
328a70aa7bbSRiver Riddle   DenseMap<Operation *, unsigned> instShiftMap;
329a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
330a70aa7bbSRiver Riddle     auto *dmaStartOp = pair.first;
331a70aa7bbSRiver Riddle     assert(isa<AffineDmaStartOp>(dmaStartOp));
332a70aa7bbSRiver Riddle     instShiftMap[dmaStartOp] = 0;
333a70aa7bbSRiver Riddle     // Set shifts for DMA start op's affine operand computation slices to 0.
334a70aa7bbSRiver Riddle     SmallVector<AffineApplyOp, 4> sliceOps;
3354c48f016SMatthias Springer     affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
336a70aa7bbSRiver Riddle     if (!sliceOps.empty()) {
337a70aa7bbSRiver Riddle       for (auto sliceOp : sliceOps) {
338a70aa7bbSRiver Riddle         instShiftMap[sliceOp.getOperation()] = 0;
339a70aa7bbSRiver Riddle       }
340a70aa7bbSRiver Riddle     } else {
341a70aa7bbSRiver Riddle       // If a slice wasn't created, the reachable affine.apply op's from its
342a70aa7bbSRiver Riddle       // operands are the ones that go with it.
343a70aa7bbSRiver Riddle       SmallVector<Operation *, 4> affineApplyInsts;
344a70aa7bbSRiver Riddle       SmallVector<Value, 4> operands(dmaStartOp->getOperands());
345a70aa7bbSRiver Riddle       getReachableAffineApplyOps(operands, affineApplyInsts);
346a70aa7bbSRiver Riddle       for (auto *op : affineApplyInsts) {
347a70aa7bbSRiver Riddle         instShiftMap[op] = 0;
348a70aa7bbSRiver Riddle       }
349a70aa7bbSRiver Riddle     }
350a70aa7bbSRiver Riddle   }
351a70aa7bbSRiver Riddle   // Everything else (including compute ops and dma finish) are shifted by one.
352a70aa7bbSRiver Riddle   for (auto &op : forOp.getBody()->without_terminator())
353*fe9f1a21SKazu Hirata     instShiftMap.try_emplace(&op, 1);
354a70aa7bbSRiver Riddle 
355a70aa7bbSRiver Riddle   // Get shifts stored in map.
356a70aa7bbSRiver Riddle   SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
357a70aa7bbSRiver Riddle   unsigned s = 0;
358a70aa7bbSRiver Riddle   for (auto &op : forOp.getBody()->without_terminator()) {
3594e585e51SKazu Hirata     assert(instShiftMap.contains(&op));
360a70aa7bbSRiver Riddle     shifts[s++] = instShiftMap[&op];
361a70aa7bbSRiver Riddle 
362a70aa7bbSRiver Riddle     // Tagging operations with shifts for debugging purposes.
363a70aa7bbSRiver Riddle     LLVM_DEBUG({
364a70aa7bbSRiver Riddle       OpBuilder b(&op);
365a70aa7bbSRiver Riddle       op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
366a70aa7bbSRiver Riddle     });
367a70aa7bbSRiver Riddle   }
368a70aa7bbSRiver Riddle 
369a70aa7bbSRiver Riddle   if (!isOpwiseShiftValid(forOp, shifts)) {
370a70aa7bbSRiver Riddle     // Violates dependences.
371a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
372a70aa7bbSRiver Riddle     return;
373a70aa7bbSRiver Riddle   }
374a70aa7bbSRiver Riddle 
375a70aa7bbSRiver Riddle   if (failed(affineForOpBodySkew(forOp, shifts))) {
376a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
377a70aa7bbSRiver Riddle     return;
378a70aa7bbSRiver Riddle   }
379a70aa7bbSRiver Riddle }
380