1 //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop software pipelining 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/SCF/IR/SCF.h" 15 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 16 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 17 #include "mlir/Dialect/SCF/Utils/Utils.h" 18 #include "mlir/IR/IRMapping.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/RegionUtils.h" 21 #include "llvm/ADT/MapVector.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/MathExtras.h" 24 25 #define DEBUG_TYPE "scf-loop-pipelining" 26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 28 29 using namespace mlir; 30 using namespace mlir::scf; 31 32 namespace { 33 34 /// Helper to keep internal information during pipelining transformation. 35 struct LoopPipelinerInternal { 36 /// Coarse liverange information for ops used across stages. 37 struct LiverangeInfo { 38 unsigned lastUseStage = 0; 39 unsigned defStage = 0; 40 }; 41 42 protected: 43 ForOp forOp; 44 unsigned maxStage = 0; 45 DenseMap<Operation *, unsigned> stages; 46 std::vector<Operation *> opOrder; 47 Value ub; 48 Value lb; 49 Value step; 50 bool dynamicLoop; 51 PipeliningOption::AnnotationlFnType annotateFn = nullptr; 52 bool peelEpilogue; 53 PipeliningOption::PredicateOpFn predicateFn = nullptr; 54 55 // When peeling the kernel we generate several version of each value for 56 // different stage of the prologue. This map tracks the mapping between 57 // original Values in the loop and the different versions 58 // peeled from the loop. 59 DenseMap<Value, llvm::SmallVector<Value>> valueMapping; 60 61 /// Assign a value to `valueMapping`, this means `val` represents the version 62 /// `idx` of `key` in the epilogue. 63 void setValueMapping(Value key, Value el, int64_t idx); 64 65 /// Return the defining op of the given value, if the Value is an argument of 66 /// the loop return the associated defining op in the loop and its distance to 67 /// the Value. 68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value); 69 70 /// Return true if the schedule is possible and return false otherwise. A 71 /// schedule is correct if all definitions are scheduled before uses. 72 bool verifySchedule(); 73 74 public: 75 /// Initalize the information for the given `op`, return true if it 76 /// satisfies the pre-condition to apply pipelining. 77 bool initializeLoopInfo(ForOp op, const PipeliningOption &options); 78 /// Emits the prologue, this creates `maxStage - 1` part which will contain 79 /// operations from stages [0; i], where i is the part index. 80 LogicalResult emitPrologue(RewriterBase &rewriter); 81 /// Gather liverange information for Values that are used in a different stage 82 /// than its definition. 83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues(); 84 scf::ForOp createKernelLoop( 85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 86 RewriterBase &rewriter, 87 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap); 88 /// Emits the pipelined kernel. This clones loop operations following user 89 /// order and remaps operands defined in a different stage as their use. 90 LogicalResult createKernel( 91 scf::ForOp newForOp, 92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 93 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 94 RewriterBase &rewriter); 95 /// Emits the epilogue, this creates `maxStage - 1` part which will contain 96 /// operations from stages [i; maxStage], where i is the part index. 97 LogicalResult emitEpilogue(RewriterBase &rewriter, 98 llvm::SmallVector<Value> &returnValues); 99 }; 100 101 bool LoopPipelinerInternal::initializeLoopInfo( 102 ForOp op, const PipeliningOption &options) { 103 LDBG("Start initializeLoopInfo"); 104 forOp = op; 105 ub = forOp.getUpperBound(); 106 lb = forOp.getLowerBound(); 107 step = forOp.getStep(); 108 109 dynamicLoop = true; 110 auto upperBoundCst = getConstantIntValue(ub); 111 auto lowerBoundCst = getConstantIntValue(lb); 112 auto stepCst = getConstantIntValue(step); 113 if (!upperBoundCst || !lowerBoundCst || !stepCst) { 114 if (!options.supportDynamicLoops) { 115 LDBG("--dynamic loop not supported -> BAIL"); 116 return false; 117 } 118 } else { 119 int64_t ubImm = upperBoundCst.value(); 120 int64_t lbImm = lowerBoundCst.value(); 121 int64_t stepImm = stepCst.value(); 122 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); 123 if (numIteration > maxStage) { 124 dynamicLoop = false; 125 } else if (!options.supportDynamicLoops) { 126 LDBG("--fewer loop iterations than pipeline stages -> BAIL"); 127 return false; 128 } 129 } 130 peelEpilogue = options.peelEpilogue; 131 predicateFn = options.predicateFn; 132 if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { 133 LDBG("--no epilogue or predicate set -> BAIL"); 134 return false; 135 } 136 std::vector<std::pair<Operation *, unsigned>> schedule; 137 options.getScheduleFn(forOp, schedule); 138 if (schedule.empty()) { 139 LDBG("--empty schedule -> BAIL"); 140 return false; 141 } 142 143 opOrder.reserve(schedule.size()); 144 for (auto &opSchedule : schedule) { 145 maxStage = std::max(maxStage, opSchedule.second); 146 stages[opSchedule.first] = opSchedule.second; 147 opOrder.push_back(opSchedule.first); 148 } 149 150 // All operations need to have a stage. 151 for (Operation &op : forOp.getBody()->without_terminator()) { 152 if (!stages.contains(&op)) { 153 op.emitOpError("not assigned a pipeline stage"); 154 LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); 155 return false; 156 } 157 } 158 159 if (!verifySchedule()) { 160 LDBG("--invalid schedule: " << op << " -> BAIL"); 161 return false; 162 } 163 164 // Currently, we do not support assigning stages to ops in nested regions. The 165 // block of all operations assigned a stage should be the single `scf.for` 166 // body block. 167 for (const auto &[op, stageNum] : stages) { 168 (void)stageNum; 169 if (op == forOp.getBody()->getTerminator()) { 170 op->emitError("terminator should not be assigned a stage"); 171 LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); 172 return false; 173 } 174 if (op->getBlock() != forOp.getBody()) { 175 op->emitOpError("the owning Block of all operations assigned a stage " 176 "should be the loop body block"); 177 LDBG("--the owning Block of all operations assigned a stage " 178 "should be the loop body block: " 179 << *op << " -> BAIL"); 180 return false; 181 } 182 } 183 184 // Support only loop-carried dependencies with a distance of one iteration or 185 // those defined outside of the loop. This means that any dependency within a 186 // loop should either be on the immediately preceding iteration, the current 187 // iteration, or on variables whose values are set before entering the loop. 188 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), 189 [this](Value operand) { 190 Operation *def = operand.getDefiningOp(); 191 return !def || 192 (!stages.contains(def) && forOp->isAncestor(def)); 193 })) { 194 LDBG("--only support loop carried dependency with a distance of 1 or " 195 "defined outside of the loop -> BAIL"); 196 return false; 197 } 198 annotateFn = options.annotateFn; 199 return true; 200 } 201 202 /// Find operands of all the nested operations within `op`. 203 static SetVector<Value> getNestedOperands(Operation *op) { 204 SetVector<Value> operands; 205 op->walk([&](Operation *nestedOp) { 206 for (Value operand : nestedOp->getOperands()) { 207 operands.insert(operand); 208 } 209 }); 210 return operands; 211 } 212 213 /// Compute unrolled cycles of each op (consumer) and verify that each op is 214 /// scheduled after its operands (producers) while adjusting for the distance 215 /// between producer and consumer. 216 bool LoopPipelinerInternal::verifySchedule() { 217 int64_t numCylesPerIter = opOrder.size(); 218 // Pre-compute the unrolled cycle of each op. 219 DenseMap<Operation *, int64_t> unrolledCyles; 220 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { 221 Operation *def = opOrder[cycle]; 222 auto it = stages.find(def); 223 assert(it != stages.end()); 224 int64_t stage = it->second; 225 unrolledCyles[def] = cycle + stage * numCylesPerIter; 226 } 227 for (Operation *consumer : opOrder) { 228 int64_t consumerCycle = unrolledCyles[consumer]; 229 for (Value operand : getNestedOperands(consumer)) { 230 auto [producer, distance] = getDefiningOpAndDistance(operand); 231 if (!producer) 232 continue; 233 auto it = unrolledCyles.find(producer); 234 // Skip producer coming from outside the loop. 235 if (it == unrolledCyles.end()) 236 continue; 237 int64_t producerCycle = it->second; 238 if (consumerCycle < producerCycle - numCylesPerIter * distance) { 239 consumer->emitError("operation scheduled before its operands"); 240 return false; 241 } 242 } 243 } 244 return true; 245 } 246 247 /// Clone `op` and call `callback` on the cloned op's oeprands as well as any 248 /// operands of nested ops that: 249 /// 1) aren't defined within the new op or 250 /// 2) are block arguments. 251 static Operation * 252 cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, 253 function_ref<void(OpOperand *newOperand)> callback) { 254 Operation *clone = rewriter.clone(*op); 255 clone->walk<WalkOrder::PreOrder>([&](Operation *nested) { 256 // 'clone' itself will be visited first. 257 for (OpOperand &operand : nested->getOpOperands()) { 258 Operation *def = operand.get().getDefiningOp(); 259 if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get())) 260 callback(&operand); 261 } 262 }); 263 return clone; 264 } 265 266 LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { 267 // Initialize the iteration argument to the loop initial values. 268 for (auto [arg, operand] : 269 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { 270 setValueMapping(arg, operand.get(), 0); 271 } 272 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 273 Location loc = forOp.getLoc(); 274 SmallVector<Value> predicates(maxStage); 275 for (int64_t i = 0; i < maxStage; i++) { 276 if (dynamicLoop) { 277 Type t = ub.getType(); 278 // pred = ub > lb + (i * step) 279 Value iv = rewriter.create<arith::AddIOp>( 280 loc, lb, 281 rewriter.create<arith::MulIOp>( 282 loc, step, 283 rewriter.create<arith::ConstantOp>( 284 loc, rewriter.getIntegerAttr(t, i)))); 285 predicates[i] = rewriter.create<arith::CmpIOp>( 286 loc, arith::CmpIPredicate::slt, iv, ub); 287 } 288 289 // special handling for induction variable as the increment is implicit. 290 // iv = lb + i * step 291 Type t = lb.getType(); 292 Value iv = rewriter.create<arith::AddIOp>( 293 loc, lb, 294 rewriter.create<arith::MulIOp>( 295 loc, step, 296 rewriter.create<arith::ConstantOp>(loc, 297 rewriter.getIntegerAttr(t, i)))); 298 setValueMapping(forOp.getInductionVar(), iv, i); 299 for (Operation *op : opOrder) { 300 if (stages[op] > i) 301 continue; 302 Operation *newOp = 303 cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { 304 auto it = valueMapping.find(newOperand->get()); 305 if (it != valueMapping.end()) { 306 Value replacement = it->second[i - stages[op]]; 307 newOperand->set(replacement); 308 } 309 }); 310 int predicateIdx = i - stages[op]; 311 if (predicates[predicateIdx]) { 312 OpBuilder::InsertionGuard insertGuard(rewriter); 313 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); 314 if (newOp == nullptr) 315 return failure(); 316 } 317 if (annotateFn) 318 annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); 319 for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { 320 Value source = newOp->getResult(destId); 321 // If the value is a loop carried dependency update the loop argument 322 for (OpOperand &operand : yield->getOpOperands()) { 323 if (operand.get() != op->getResult(destId)) 324 continue; 325 if (predicates[predicateIdx] && 326 !forOp.getResult(operand.getOperandNumber()).use_empty()) { 327 // If the value is used outside the loop, we need to make sure we 328 // return the correct version of it. 329 Value prevValue = valueMapping 330 [forOp.getRegionIterArgs()[operand.getOperandNumber()]] 331 [i - stages[op]]; 332 source = rewriter.create<arith::SelectOp>( 333 loc, predicates[predicateIdx], source, prevValue); 334 } 335 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], 336 source, i - stages[op] + 1); 337 } 338 setValueMapping(op->getResult(destId), newOp->getResult(destId), 339 i - stages[op]); 340 } 341 } 342 } 343 return success(); 344 } 345 346 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 347 LoopPipelinerInternal::analyzeCrossStageValues() { 348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues; 349 for (Operation *op : opOrder) { 350 unsigned stage = stages[op]; 351 352 auto analyzeOperand = [&](OpOperand &operand) { 353 auto [def, distance] = getDefiningOpAndDistance(operand.get()); 354 if (!def) 355 return; 356 auto defStage = stages.find(def); 357 if (defStage == stages.end() || defStage->second == stage || 358 defStage->second == stage + distance) 359 return; 360 assert(stage > defStage->second); 361 LiverangeInfo &info = crossStageValues[operand.get()]; 362 info.defStage = defStage->second; 363 info.lastUseStage = std::max(info.lastUseStage, stage); 364 }; 365 366 for (OpOperand &operand : op->getOpOperands()) 367 analyzeOperand(operand); 368 visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { 369 analyzeOperand(*operand); 370 }); 371 } 372 return crossStageValues; 373 } 374 375 std::pair<Operation *, int64_t> 376 LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { 377 int64_t distance = 0; 378 if (auto arg = dyn_cast<BlockArgument>(value)) { 379 if (arg.getOwner() != forOp.getBody()) 380 return {nullptr, 0}; 381 // Ignore induction variable. 382 if (arg.getArgNumber() == 0) 383 return {nullptr, 0}; 384 distance++; 385 value = 386 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); 387 } 388 Operation *def = value.getDefiningOp(); 389 if (!def) 390 return {nullptr, 0}; 391 return {def, distance}; 392 } 393 394 scf::ForOp LoopPipelinerInternal::createKernelLoop( 395 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 396 &crossStageValues, 397 RewriterBase &rewriter, 398 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) { 399 // Creates the list of initial values associated to values used across 400 // stages. The initial values come from the prologue created above. 401 // Keep track of the kernel argument associated to each version of the 402 // values passed to the kernel. 403 llvm::SmallVector<Value> newLoopArg; 404 // For existing loop argument initialize them with the right version from the 405 // prologue. 406 for (const auto &retVal : 407 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 408 Operation *def = retVal.value().getDefiningOp(); 409 assert(def && "Only support loop carried dependencies of distance of 1 or " 410 "outside the loop"); 411 auto defStage = stages.find(def); 412 if (defStage != stages.end()) { 413 Value valueVersion = 414 valueMapping[forOp.getRegionIterArgs()[retVal.index()]] 415 [maxStage - defStage->second]; 416 assert(valueVersion); 417 newLoopArg.push_back(valueVersion); 418 } else 419 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); 420 } 421 for (auto escape : crossStageValues) { 422 LiverangeInfo &info = escape.second; 423 Value value = escape.first; 424 for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; 425 stageIdx++) { 426 Value valueVersion = 427 valueMapping[value][maxStage - info.lastUseStage + stageIdx]; 428 assert(valueVersion); 429 newLoopArg.push_back(valueVersion); 430 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - 431 stageIdx)] = newLoopArg.size() - 1; 432 } 433 } 434 435 // Create the new kernel loop. When we peel the epilgue we need to peel 436 // `numStages - 1` iterations. Then we adjust the upper bound to remove those 437 // iterations. 438 Value newUb = forOp.getUpperBound(); 439 if (peelEpilogue) { 440 Type t = ub.getType(); 441 Location loc = forOp.getLoc(); 442 // newUb = ub - maxStage * step 443 Value maxStageValue = rewriter.create<arith::ConstantOp>( 444 loc, rewriter.getIntegerAttr(t, maxStage)); 445 Value maxStageByStep = 446 rewriter.create<arith::MulIOp>(loc, step, maxStageValue); 447 newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep); 448 } 449 auto newForOp = 450 rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb, 451 forOp.getStep(), newLoopArg); 452 // When there are no iter args, the loop body terminator will be created. 453 // Since we always create it below, remove the terminator if it was created. 454 if (!newForOp.getBody()->empty()) 455 rewriter.eraseOp(newForOp.getBody()->getTerminator()); 456 return newForOp; 457 } 458 459 LogicalResult LoopPipelinerInternal::createKernel( 460 scf::ForOp newForOp, 461 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 462 &crossStageValues, 463 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 464 RewriterBase &rewriter) { 465 valueMapping.clear(); 466 467 // Create the kernel, we clone instruction based on the order given by 468 // user and remap operands coming from a previous stages. 469 rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); 470 IRMapping mapping; 471 mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); 472 for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { 473 mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); 474 } 475 SmallVector<Value> predicates(maxStage + 1, nullptr); 476 if (!peelEpilogue) { 477 // Create a predicate for each stage except the last stage. 478 Location loc = newForOp.getLoc(); 479 Type t = ub.getType(); 480 for (unsigned i = 0; i < maxStage; i++) { 481 // c = ub - (maxStage - i) * step 482 Value c = rewriter.create<arith::SubIOp>( 483 loc, ub, 484 rewriter.create<arith::MulIOp>( 485 loc, step, 486 rewriter.create<arith::ConstantOp>( 487 loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); 488 489 Value pred = rewriter.create<arith::CmpIOp>( 490 newForOp.getLoc(), arith::CmpIPredicate::slt, 491 newForOp.getInductionVar(), c); 492 predicates[i] = pred; 493 } 494 } 495 for (Operation *op : opOrder) { 496 int64_t useStage = stages[op]; 497 auto *newOp = rewriter.clone(*op, mapping); 498 SmallVector<OpOperand *> operands; 499 // Collect all the operands for the cloned op and its nested ops. 500 op->walk([&operands](Operation *nestedOp) { 501 for (OpOperand &operand : nestedOp->getOpOperands()) { 502 operands.push_back(&operand); 503 } 504 }); 505 for (OpOperand *operand : operands) { 506 Operation *nestedNewOp = mapping.lookup(operand->getOwner()); 507 // Special case for the induction variable uses. We replace it with a 508 // version incremented based on the stage where it is used. 509 if (operand->get() == forOp.getInductionVar()) { 510 rewriter.setInsertionPoint(newOp); 511 512 // offset = (maxStage - stages[op]) * step 513 Type t = step.getType(); 514 Value offset = rewriter.create<arith::MulIOp>( 515 forOp.getLoc(), step, 516 rewriter.create<arith::ConstantOp>( 517 forOp.getLoc(), 518 rewriter.getIntegerAttr(t, maxStage - stages[op]))); 519 Value iv = rewriter.create<arith::AddIOp>( 520 forOp.getLoc(), newForOp.getInductionVar(), offset); 521 nestedNewOp->setOperand(operand->getOperandNumber(), iv); 522 rewriter.setInsertionPointAfter(newOp); 523 continue; 524 } 525 Value source = operand->get(); 526 auto arg = dyn_cast<BlockArgument>(source); 527 if (arg && arg.getOwner() == forOp.getBody()) { 528 Value ret = forOp.getBody()->getTerminator()->getOperand( 529 arg.getArgNumber() - 1); 530 Operation *dep = ret.getDefiningOp(); 531 if (!dep) 532 continue; 533 auto stageDep = stages.find(dep); 534 if (stageDep == stages.end() || stageDep->second == useStage) 535 continue; 536 // If the value is a loop carried value coming from stage N + 1 remap, 537 // it will become a direct use. 538 if (stageDep->second == useStage + 1) { 539 nestedNewOp->setOperand(operand->getOperandNumber(), 540 mapping.lookupOrDefault(ret)); 541 continue; 542 } 543 source = ret; 544 } 545 // For operands defined in a previous stage we need to remap it to use 546 // the correct region argument. We look for the right version of the 547 // Value based on the stage where it is used. 548 Operation *def = source.getDefiningOp(); 549 if (!def) 550 continue; 551 auto stageDef = stages.find(def); 552 if (stageDef == stages.end() || stageDef->second == useStage) 553 continue; 554 auto remap = loopArgMap.find( 555 std::make_pair(operand->get(), useStage - stageDef->second)); 556 assert(remap != loopArgMap.end()); 557 nestedNewOp->setOperand(operand->getOperandNumber(), 558 newForOp.getRegionIterArgs()[remap->second]); 559 } 560 561 if (predicates[useStage]) { 562 OpBuilder::InsertionGuard insertGuard(rewriter); 563 newOp = predicateFn(rewriter, newOp, predicates[useStage]); 564 if (!newOp) 565 return failure(); 566 // Remap the results to the new predicated one. 567 for (auto values : llvm::zip(op->getResults(), newOp->getResults())) 568 mapping.map(std::get<0>(values), std::get<1>(values)); 569 } 570 if (annotateFn) 571 annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); 572 } 573 574 // Collect the Values that need to be returned by the forOp. For each 575 // value we need to have `LastUseStage - DefStage` number of versions 576 // returned. 577 // We create a mapping between original values and the associated loop 578 // returned values that will be needed by the epilogue. 579 llvm::SmallVector<Value> yieldOperands; 580 for (OpOperand &yieldOperand : 581 forOp.getBody()->getTerminator()->getOpOperands()) { 582 Value source = mapping.lookupOrDefault(yieldOperand.get()); 583 // When we don't peel the epilogue and the yield value is used outside the 584 // loop we need to make sure we return the version from numStages - 585 // defStage. 586 if (!peelEpilogue && 587 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { 588 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; 589 if (def) { 590 auto defStage = stages.find(def); 591 if (defStage != stages.end() && defStage->second < maxStage) { 592 Value pred = predicates[defStage->second]; 593 source = rewriter.create<arith::SelectOp>( 594 pred.getLoc(), pred, source, 595 newForOp.getBody() 596 ->getArguments()[yieldOperand.getOperandNumber() + 1]); 597 } 598 } 599 } 600 yieldOperands.push_back(source); 601 } 602 603 for (auto &it : crossStageValues) { 604 int64_t version = maxStage - it.second.lastUseStage + 1; 605 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; 606 // add the original version to yield ops. 607 // If there is a live range spanning across more than 2 stages we need to 608 // add extra arg. 609 for (unsigned i = 1; i < numVersionReturned; i++) { 610 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 611 version++); 612 yieldOperands.push_back( 613 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + 614 newForOp.getNumInductionVars()]); 615 } 616 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 617 version++); 618 yieldOperands.push_back(mapping.lookupOrDefault(it.first)); 619 } 620 // Map the yield operand to the forOp returned value. 621 for (const auto &retVal : 622 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 623 Operation *def = retVal.value().getDefiningOp(); 624 assert(def && "Only support loop carried dependencies of distance of 1 or " 625 "defined outside the loop"); 626 auto defStage = stages.find(def); 627 if (defStage == stages.end()) { 628 for (unsigned int stage = 1; stage <= maxStage; stage++) 629 setValueMapping(forOp.getRegionIterArgs()[retVal.index()], 630 retVal.value(), stage); 631 } else if (defStage->second > 0) { 632 setValueMapping(forOp.getRegionIterArgs()[retVal.index()], 633 newForOp->getResult(retVal.index()), 634 maxStage - defStage->second + 1); 635 } 636 } 637 rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); 638 return success(); 639 } 640 641 LogicalResult 642 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, 643 llvm::SmallVector<Value> &returnValues) { 644 Location loc = forOp.getLoc(); 645 Type t = lb.getType(); 646 647 // Emit different versions of the induction variable. They will be 648 // removed by dead code if not used. 649 650 auto createConst = [&](int v) { 651 return rewriter.create<arith::ConstantOp>(loc, 652 rewriter.getIntegerAttr(t, v)); 653 }; 654 655 // total_iterations = cdiv(range_diff, step); 656 // - range_diff = ub - lb 657 // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step 658 Value zero = createConst(0); 659 Value one = createConst(1); 660 Value stepLessZero = rewriter.create<arith::CmpIOp>( 661 loc, arith::CmpIPredicate::slt, step, zero); 662 Value stepDecr = 663 rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1)); 664 665 Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb); 666 Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step); 667 Value rangeDecr = 668 rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr); 669 Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step); 670 671 // If total_iters < max_stage, start the epilogue at zero to match the 672 // ramp-up in the prologue. 673 // start_iter = max(0, total_iters - max_stage) 674 Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations, 675 createConst(maxStage)); 676 iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI); 677 678 // Capture predicates for dynamic loops. 679 SmallVector<Value> predicates(maxStage + 1); 680 681 for (int64_t i = 1; i <= maxStage; i++) { 682 // newLastIter = lb + step * iterI 683 Value newlastIter = rewriter.create<arith::AddIOp>( 684 loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI)); 685 686 setValueMapping(forOp.getInductionVar(), newlastIter, i); 687 688 // increment to next iterI 689 iterI = rewriter.create<arith::AddIOp>(loc, iterI, one); 690 691 if (dynamicLoop) { 692 // Disable stages when `i` is greater than total_iters. 693 // pred = total_iters >= i 694 predicates[i] = rewriter.create<arith::CmpIOp>( 695 loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); 696 } 697 } 698 699 // Emit `maxStage - 1` epilogue part that includes operations from stages 700 // [i; maxStage]. 701 for (int64_t i = 1; i <= maxStage; i++) { 702 SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size()); 703 for (Operation *op : opOrder) { 704 if (stages[op] < i) 705 continue; 706 unsigned currentVersion = maxStage - stages[op] + i; 707 unsigned nextVersion = currentVersion + 1; 708 Operation *newOp = 709 cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { 710 auto it = valueMapping.find(newOperand->get()); 711 if (it != valueMapping.end()) { 712 Value replacement = it->second[currentVersion]; 713 newOperand->set(replacement); 714 } 715 }); 716 if (dynamicLoop) { 717 OpBuilder::InsertionGuard insertGuard(rewriter); 718 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); 719 if (!newOp) 720 return failure(); 721 } 722 if (annotateFn) 723 annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); 724 725 for (auto [opRes, newRes] : 726 llvm::zip(op->getResults(), newOp->getResults())) { 727 setValueMapping(opRes, newRes, currentVersion); 728 // If the value is a loop carried dependency update the loop argument 729 // mapping and keep track of the last version to replace the original 730 // forOp uses. 731 for (OpOperand &operand : 732 forOp.getBody()->getTerminator()->getOpOperands()) { 733 if (operand.get() != opRes) 734 continue; 735 // If the version is greater than maxStage it means it maps to the 736 // original forOp returned value. 737 unsigned ri = operand.getOperandNumber(); 738 returnValues[ri] = newRes; 739 Value mapVal = forOp.getRegionIterArgs()[ri]; 740 returnMap[ri] = std::make_pair(mapVal, currentVersion); 741 if (nextVersion <= maxStage) 742 setValueMapping(mapVal, newRes, nextVersion); 743 } 744 } 745 } 746 if (dynamicLoop) { 747 // Select return values from this stage (live outs) based on predication. 748 // If the stage is valid select the peeled value, else use previous stage 749 // value. 750 for (auto pair : llvm::enumerate(returnValues)) { 751 unsigned ri = pair.index(); 752 auto [mapVal, currentVersion] = returnMap[ri]; 753 if (mapVal) { 754 unsigned nextVersion = currentVersion + 1; 755 Value pred = predicates[currentVersion]; 756 Value prevValue = valueMapping[mapVal][currentVersion]; 757 auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(), 758 prevValue); 759 returnValues[ri] = selOp; 760 if (nextVersion <= maxStage) 761 setValueMapping(mapVal, selOp, nextVersion); 762 } 763 } 764 } 765 } 766 return success(); 767 } 768 769 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { 770 auto it = valueMapping.find(key); 771 // If the value is not in the map yet add a vector big enough to store all 772 // versions. 773 if (it == valueMapping.end()) 774 it = 775 valueMapping 776 .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1))) 777 .first; 778 it->second[idx] = el; 779 } 780 781 } // namespace 782 783 FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, 784 const PipeliningOption &options, 785 bool *modifiedIR) { 786 if (modifiedIR) 787 *modifiedIR = false; 788 LoopPipelinerInternal pipeliner; 789 if (!pipeliner.initializeLoopInfo(forOp, options)) 790 return failure(); 791 792 if (modifiedIR) 793 *modifiedIR = true; 794 795 // 1. Emit prologue. 796 if (failed(pipeliner.emitPrologue(rewriter))) 797 return failure(); 798 799 // 2. Track values used across stages. When a value cross stages it will 800 // need to be passed as loop iteration arguments. 801 // We first collect the values that are used in a different stage than where 802 // they are defined. 803 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 804 crossStageValues = pipeliner.analyzeCrossStageValues(); 805 806 // Mapping between original loop values used cross stage and the block 807 // arguments associated after pipelining. A Value may map to several 808 // arguments if its liverange spans across more than 2 stages. 809 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap; 810 // 3. Create the new kernel loop and return the block arguments mapping. 811 ForOp newForOp = 812 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); 813 // Create the kernel block, order ops based on user choice and remap 814 // operands. 815 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, 816 rewriter))) 817 return failure(); 818 819 llvm::SmallVector<Value> returnValues = 820 newForOp.getResults().take_front(forOp->getNumResults()); 821 if (options.peelEpilogue) { 822 // 4. Emit the epilogue after the new forOp. 823 rewriter.setInsertionPointAfter(newForOp); 824 if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) 825 return failure(); 826 } 827 // 5. Erase the original loop and replace the uses with the epilogue output. 828 if (forOp->getNumResults() > 0) 829 rewriter.replaceOp(forOp, returnValues); 830 else 831 rewriter.eraseOp(forOp); 832 833 return newForOp; 834 } 835 836 void mlir::scf::populateSCFLoopPipeliningPatterns( 837 RewritePatternSet &patterns, const PipeliningOption &options) { 838 patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext()); 839 } 840