//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;
using namespace mlir::affine;

namespace {
struct PipelineDataTransfer
    : public affine::impl::AffinePipelineDataTransferBase<
          PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added.
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// (normalized by its step) modulo 2. Returns false if such a replacement
/// cannot be performed.
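/// For example, a buffer of type 'memref<256xf32, 1>' becomes
/// 'memref<2x256xf32, 1>', and an access 'A[%j]' in the loop body effectively
/// becomes 'A[(%iv floordiv step) mod 2, %j]'.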
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create the '(iv floordiv step) mod 2' value to index the leading
  // dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStepAsInt();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
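  // E.g., with a zero lower bound and step 4, IVs 0, 4, 8, 12 map to leading
  // indices 0, 1, 0, 1, so successive iterations alternate buffer copies.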
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Pipelines the data transfers within each 'affine.for' op of the function.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined one
  // gets deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
    // TODO: this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough for
    // now since the DMA generation pass or the input for it will always have
    // start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO: use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                       << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlaps DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
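/// Schematically, a loop of the form
///
///   affine.for %i { dma_start(%i) ; dma_wait(%i) ; compute(%i) }
///
/// becomes, roughly,
///
///   dma_start(first)
///   affine.for %i (all but the first iteration) {
///     dma_start(%i) ; dma_wait(%i - 1) ; compute(%i - 1)
///   }
///   dma_wait(last) ; compute(last)
///
/// so that iteration %i's incoming transfer overlaps with iteration %i-1's
/// computation, with the double-buffered memrefs keeping the two in flight.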
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that
  // the dimension we are adding for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                     << "double buffering failed for " << dmaStartOp << "\n";);
      // IR still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store each operation's shift for later lookup for AffineApplyOp's.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and DMA finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    instShiftMap.try_emplace(&op, 1);
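  // With these shifts, the skewing below issues iteration i's DMA start
  // alongside iteration i-1's wait and compute, overlapping the transfer for
  // the next iteration with the computation on the current one.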

  // Get shifts stored in map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.contains(&op));
    shifts[s++] = instShiftMap[&op];

    // Tagging operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}