//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;
using namespace mlir::affine;

namespace {
struct PipelineDataTransfer
    : public affine::impl::AffinePipelineDataTransferBase<
          PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's
// are added.
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
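/// For example (an illustrative sketch, not IR emitted verbatim by this pass),
/// for a unit-step loop:
///
///   %buf = memref.alloc() : memref<128xf32, 2>
///   affine.for %i = 0 to 128 {
///     %v = affine.load %buf[%i] : memref<128xf32, 2>
///   }
///
/// the buffer and its accesses become:
///
///   %buf = memref.alloc() : memref<2x128xf32, 2>
///   affine.for %i = 0 to 128 {
///     %mod = affine.apply affine_map<(d0) -> (d0 mod 2)>(%i)
///     %v = affine.load %buf[%mod, %i] : memref<2x128xf32, 2>
///   }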
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create the value '(iv floordiv step) mod 2' to index the leading
  // dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStepAsInt();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Pipelines data transfers in each eligible 'affine.for' op of the function.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined
  // one gets deleted and replaced by a prologue, a new steady-state loop and
  // an epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if the tags of the dma start op and dma wait op match.
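// For example (illustrative), the following pair matches: both ops use the
// same tag memref SSA value and the same tag index operand:
//
//   affine.dma_start %src[%i], %buf[%i], %tag[%c0], %num
//       : memref<256xf32>, memref<256xf32, 2>, memref<1xi32, 2>
//   affine.dma_wait %tag[%c0], %num : memref<1xi32, 2>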
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to
  // the same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if the indices match.
    // TODO: this would in general need to check if there is no intervening
    // write writing to the same tag location, i.e., memory last write/data
    // flow analysis. This is however sufficient/powerful enough for now since
    // the DMA generation pass or the input for it will always have start/wait
    // with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. We do this conservatively.
    // TODO: use the dependence analysis to check for dependences between an
    // incoming and an outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlaps DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
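/// Schematically (an illustrative sketch for a loop with trip count N), the
/// DMA starts are shifted to run one iteration ahead of the waits and the
/// compute, with the double buffer letting iterations i and i+1 use distinct
/// halves:
///
///   prologue:      dma_start(0)
///   steady state:  for i = 0 .. N-2 { dma_start(i+1); dma_wait(i);
///                                     compute(i) }
///   epilogue:      dma_wait(N-1); compute(N-1)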
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no dma start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that the
  // dimension we are adding for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << *dmaStartOp << "\n";);
      // The IR is still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we handle the
    // simple/common case here anyway so that the output / test cases look
    // clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for later lookup for AffineApplyOp's.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for the DMA start op's affine operand computation slices to
    // 0.
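    // For instance (an illustrative sketch), if a DMA start index is computed
    // as
    //   %idx = affine.apply affine_map<(d0) -> (d0 * 128)>(%i)
    // the slice recreates a private affine.apply feeding only this DMA start,
    // and that private op also gets shift 0 so it moves along with it.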
    SmallVector<AffineApplyOp, 4> sliceOps;
    affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    instShiftMap.try_emplace(&op, 1);

  // Get the shifts stored in the map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.contains(&op));
    shifts[s++] = instShiftMap[&op];

    // Tag operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}