1f6f88e66Sthomasraoux //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// 2f6f88e66Sthomasraoux // 3f6f88e66Sthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f6f88e66Sthomasraoux // See https://llvm.org/LICENSE.txt for license information. 5f6f88e66Sthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f6f88e66Sthomasraoux // 7f6f88e66Sthomasraoux //===----------------------------------------------------------------------===// 8f6f88e66Sthomasraoux // 9f6f88e66Sthomasraoux // This file implements loop software pipelining 10f6f88e66Sthomasraoux // 11f6f88e66Sthomasraoux //===----------------------------------------------------------------------===// 12f6f88e66Sthomasraoux 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 17f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h" 184d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 19f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h" 20f5fe92f6SChristopher Bate #include "mlir/Transforms/RegionUtils.h" 2167d0d7acSMichele Scuttari #include "llvm/ADT/MapVector.h" 22c7592c77SNicolas Vasilache #include "llvm/Support/Debug.h" 230fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h" 24c7592c77SNicolas Vasilache 25c7592c77SNicolas Vasilache #define DEBUG_TYPE "scf-loop-pipelining" 26c7592c77SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 27c7592c77SNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 28f6f88e66Sthomasraoux 29f6f88e66Sthomasraoux using namespace mlir; 30f6f88e66Sthomasraoux using namespace mlir::scf; 31f6f88e66Sthomasraoux 32f6f88e66Sthomasraoux namespace { 33f6f88e66Sthomasraoux 34f6f88e66Sthomasraoux /// Helper to keep internal information during pipelining transformation. 35f6f88e66Sthomasraoux struct LoopPipelinerInternal { 36f6f88e66Sthomasraoux /// Coarse liverange information for ops used across stages. 37f6f88e66Sthomasraoux struct LiverangeInfo { 38f6f88e66Sthomasraoux unsigned lastUseStage = 0; 39f6f88e66Sthomasraoux unsigned defStage = 0; 40f6f88e66Sthomasraoux }; 41f6f88e66Sthomasraoux 42f6f88e66Sthomasraoux protected: 43f6f88e66Sthomasraoux ForOp forOp; 44f6f88e66Sthomasraoux unsigned maxStage = 0; 45f6f88e66Sthomasraoux DenseMap<Operation *, unsigned> stages; 46f6f88e66Sthomasraoux std::vector<Operation *> opOrder; 47ef112833SThomas Raoux Value ub; 48ef112833SThomas Raoux Value lb; 49ef112833SThomas Raoux Value step; 50ef112833SThomas Raoux bool dynamicLoop; 510736bbd7SThomas Raoux PipeliningOption::AnnotationlFnType annotateFn = nullptr; 52205c08b5SThomas Raoux bool peelEpilogue; 53205c08b5SThomas Raoux PipeliningOption::PredicateOpFn predicateFn = nullptr; 54f6f88e66Sthomasraoux 55f6f88e66Sthomasraoux // When peeling the kernel we generate several version of each value for 56f6f88e66Sthomasraoux // different stage of the prologue. This map tracks the mapping between 57f6f88e66Sthomasraoux // original Values in the loop and the different versions 58f6f88e66Sthomasraoux // peeled from the loop. 59f6f88e66Sthomasraoux DenseMap<Value, llvm::SmallVector<Value>> valueMapping; 60f6f88e66Sthomasraoux 61f6f88e66Sthomasraoux /// Assign a value to `valueMapping`, this means `val` represents the version 62f6f88e66Sthomasraoux /// `idx` of `key` in the epilogue. 63f6f88e66Sthomasraoux void setValueMapping(Value key, Value el, int64_t idx); 64f6f88e66Sthomasraoux 6519e068b0SThomas Raoux /// Return the defining op of the given value, if the Value is an argument of 6619e068b0SThomas Raoux /// the loop return the associated defining op in the loop and its distance to 6719e068b0SThomas Raoux /// the Value. 6819e068b0SThomas Raoux std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value); 6919e068b0SThomas Raoux 70c933bd81SThomas Raoux /// Return true if the schedule is possible and return false otherwise. A 71c933bd81SThomas Raoux /// schedule is correct if all definitions are scheduled before uses. 72c933bd81SThomas Raoux bool verifySchedule(); 73c933bd81SThomas Raoux 74f6f88e66Sthomasraoux public: 75f6f88e66Sthomasraoux /// Initalize the information for the given `op`, return true if it 76f6f88e66Sthomasraoux /// satisfies the pre-condition to apply pipelining. 77f6f88e66Sthomasraoux bool initializeLoopInfo(ForOp op, const PipeliningOption &options); 78f6f88e66Sthomasraoux /// Emits the prologue, this creates `maxStage - 1` part which will contain 79f6f88e66Sthomasraoux /// operations from stages [0; i], where i is the part index. 8018926666SSJW LogicalResult emitPrologue(RewriterBase &rewriter); 81f6f88e66Sthomasraoux /// Gather liverange information for Values that are used in a different stage 82f6f88e66Sthomasraoux /// than its definition. 83f6f88e66Sthomasraoux llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues(); 84f6f88e66Sthomasraoux scf::ForOp createKernelLoop( 85f6f88e66Sthomasraoux const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 861cff4cbdSNicolas Vasilache RewriterBase &rewriter, 87f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap); 88f6f88e66Sthomasraoux /// Emits the pipelined kernel. This clones loop operations following user 89f6f88e66Sthomasraoux /// order and remaps operands defined in a different stage as their use. 90371366ceSAlex Zinenko LogicalResult createKernel( 91f6f88e66Sthomasraoux scf::ForOp newForOp, 92f6f88e66Sthomasraoux const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 93f6f88e66Sthomasraoux const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 941cff4cbdSNicolas Vasilache RewriterBase &rewriter); 95f6f88e66Sthomasraoux /// Emits the epilogue, this creates `maxStage - 1` part which will contain 96f6f88e66Sthomasraoux /// operations from stages [i; maxStage], where i is the part index. 97ebf05993SSJW LogicalResult emitEpilogue(RewriterBase &rewriter, 98e66f97e8SKeren Zhou llvm::SmallVector<Value> &returnValues); 99f6f88e66Sthomasraoux }; 100f6f88e66Sthomasraoux 101f6f88e66Sthomasraoux bool LoopPipelinerInternal::initializeLoopInfo( 102f6f88e66Sthomasraoux ForOp op, const PipeliningOption &options) { 103c7592c77SNicolas Vasilache LDBG("Start initializeLoopInfo"); 104f6f88e66Sthomasraoux forOp = op; 105ef112833SThomas Raoux ub = forOp.getUpperBound(); 106ef112833SThomas Raoux lb = forOp.getLowerBound(); 107ef112833SThomas Raoux step = forOp.getStep(); 108ef112833SThomas Raoux 109ef112833SThomas Raoux dynamicLoop = true; 110ef112833SThomas Raoux auto upperBoundCst = getConstantIntValue(ub); 111ef112833SThomas Raoux auto lowerBoundCst = getConstantIntValue(lb); 112ef112833SThomas Raoux auto stepCst = getConstantIntValue(step); 113c7592c77SNicolas Vasilache if (!upperBoundCst || !lowerBoundCst || !stepCst) { 114ef112833SThomas Raoux if (!options.supportDynamicLoops) { 115ef112833SThomas Raoux LDBG("--dynamic loop not supported -> BAIL"); 116f6f88e66Sthomasraoux return false; 117c7592c77SNicolas Vasilache } 118ef112833SThomas Raoux } else { 119ef112833SThomas Raoux int64_t ubImm = upperBoundCst.value(); 120ef112833SThomas Raoux int64_t lbImm = lowerBoundCst.value(); 121ef112833SThomas Raoux int64_t stepImm = stepCst.value(); 1220fb216fbSRamkumar Ramachandra int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); 123ef112833SThomas Raoux if (numIteration > maxStage) { 124ef112833SThomas Raoux dynamicLoop = false; 125ef112833SThomas Raoux } else if (!options.supportDynamicLoops) { 126ef112833SThomas Raoux LDBG("--fewer loop iterations than pipeline stages -> BAIL"); 127ef112833SThomas Raoux return false; 128ef112833SThomas Raoux } 129ef112833SThomas Raoux } 130205c08b5SThomas Raoux peelEpilogue = options.peelEpilogue; 131205c08b5SThomas Raoux predicateFn = options.predicateFn; 132ef112833SThomas Raoux if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { 133c7592c77SNicolas Vasilache LDBG("--no epilogue or predicate set -> BAIL"); 134205c08b5SThomas Raoux return false; 135c7592c77SNicolas Vasilache } 136f6f88e66Sthomasraoux std::vector<std::pair<Operation *, unsigned>> schedule; 137f6f88e66Sthomasraoux options.getScheduleFn(forOp, schedule); 138c7592c77SNicolas Vasilache if (schedule.empty()) { 139c7592c77SNicolas Vasilache LDBG("--empty schedule -> BAIL"); 140f6f88e66Sthomasraoux return false; 141c7592c77SNicolas Vasilache } 142f6f88e66Sthomasraoux 143f6f88e66Sthomasraoux opOrder.reserve(schedule.size()); 144f6f88e66Sthomasraoux for (auto &opSchedule : schedule) { 145f6f88e66Sthomasraoux maxStage = std::max(maxStage, opSchedule.second); 146f6f88e66Sthomasraoux stages[opSchedule.first] = opSchedule.second; 147f6f88e66Sthomasraoux opOrder.push_back(opSchedule.first); 148f6f88e66Sthomasraoux } 149f6f88e66Sthomasraoux 150f6f88e66Sthomasraoux // All operations need to have a stage. 151f5fe92f6SChristopher Bate for (Operation &op : forOp.getBody()->without_terminator()) { 152ce14f7b1SKazu Hirata if (!stages.contains(&op)) { 153f5fe92f6SChristopher Bate op.emitOpError("not assigned a pipeline stage"); 154c7592c77SNicolas Vasilache LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); 155f6f88e66Sthomasraoux return false; 156f5fe92f6SChristopher Bate } 157f5fe92f6SChristopher Bate } 158f5fe92f6SChristopher Bate 159c933bd81SThomas Raoux if (!verifySchedule()) { 160c933bd81SThomas Raoux LDBG("--invalid schedule: " << op << " -> BAIL"); 161c933bd81SThomas Raoux return false; 162c933bd81SThomas Raoux } 163c933bd81SThomas Raoux 164f5fe92f6SChristopher Bate // Currently, we do not support assigning stages to ops in nested regions. The 165f5fe92f6SChristopher Bate // block of all operations assigned a stage should be the single `scf.for` 166f5fe92f6SChristopher Bate // body block. 167f5fe92f6SChristopher Bate for (const auto &[op, stageNum] : stages) { 168f5fe92f6SChristopher Bate (void)stageNum; 169f5fe92f6SChristopher Bate if (op == forOp.getBody()->getTerminator()) { 170f5fe92f6SChristopher Bate op->emitError("terminator should not be assigned a stage"); 171c7592c77SNicolas Vasilache LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); 172f5fe92f6SChristopher Bate return false; 173f5fe92f6SChristopher Bate } 174f5fe92f6SChristopher Bate if (op->getBlock() != forOp.getBody()) { 175f5fe92f6SChristopher Bate op->emitOpError("the owning Block of all operations assigned a stage " 176f5fe92f6SChristopher Bate "should be the loop body block"); 177c7592c77SNicolas Vasilache LDBG("--the owning Block of all operations assigned a stage " 178c7592c77SNicolas Vasilache "should be the loop body block: " 179c7592c77SNicolas Vasilache << *op << " -> BAIL"); 180f5fe92f6SChristopher Bate return false; 181f5fe92f6SChristopher Bate } 182f5fe92f6SChristopher Bate } 183f6f88e66Sthomasraoux 184e66f97e8SKeren Zhou // Support only loop-carried dependencies with a distance of one iteration or 185e66f97e8SKeren Zhou // those defined outside of the loop. This means that any dependency within a 186e66f97e8SKeren Zhou // loop should either be on the immediately preceding iteration, the current 187e66f97e8SKeren Zhou // iteration, or on variables whose values are set before entering the loop. 18845cb4140Sthomasraoux if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), 18945cb4140Sthomasraoux [this](Value operand) { 19045cb4140Sthomasraoux Operation *def = operand.getDefiningOp(); 191e66f97e8SKeren Zhou return !def || 192e66f97e8SKeren Zhou (!stages.contains(def) && forOp->isAncestor(def)); 193c7592c77SNicolas Vasilache })) { 194e66f97e8SKeren Zhou LDBG("--only support loop carried dependency with a distance of 1 or " 195e66f97e8SKeren Zhou "defined outside of the loop -> BAIL"); 196f6f88e66Sthomasraoux return false; 197c7592c77SNicolas Vasilache } 1980736bbd7SThomas Raoux annotateFn = options.annotateFn; 199f6f88e66Sthomasraoux return true; 200f6f88e66Sthomasraoux } 201f6f88e66Sthomasraoux 20256954a53Spawelszczerbuk /// Find operands of all the nested operations within `op`. 20356954a53Spawelszczerbuk static SetVector<Value> getNestedOperands(Operation *op) { 20456954a53Spawelszczerbuk SetVector<Value> operands; 20556954a53Spawelszczerbuk op->walk([&](Operation *nestedOp) { 20656954a53Spawelszczerbuk for (Value operand : nestedOp->getOperands()) { 20756954a53Spawelszczerbuk operands.insert(operand); 20856954a53Spawelszczerbuk } 20956954a53Spawelszczerbuk }); 21056954a53Spawelszczerbuk return operands; 21156954a53Spawelszczerbuk } 21256954a53Spawelszczerbuk 213c933bd81SThomas Raoux /// Compute unrolled cycles of each op (consumer) and verify that each op is 214c933bd81SThomas Raoux /// scheduled after its operands (producers) while adjusting for the distance 215c933bd81SThomas Raoux /// between producer and consumer. 216c933bd81SThomas Raoux bool LoopPipelinerInternal::verifySchedule() { 217c933bd81SThomas Raoux int64_t numCylesPerIter = opOrder.size(); 218c933bd81SThomas Raoux // Pre-compute the unrolled cycle of each op. 219c933bd81SThomas Raoux DenseMap<Operation *, int64_t> unrolledCyles; 220c933bd81SThomas Raoux for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { 221c933bd81SThomas Raoux Operation *def = opOrder[cycle]; 222c933bd81SThomas Raoux auto it = stages.find(def); 223c933bd81SThomas Raoux assert(it != stages.end()); 224c933bd81SThomas Raoux int64_t stage = it->second; 225c933bd81SThomas Raoux unrolledCyles[def] = cycle + stage * numCylesPerIter; 226c933bd81SThomas Raoux } 227c933bd81SThomas Raoux for (Operation *consumer : opOrder) { 228c933bd81SThomas Raoux int64_t consumerCycle = unrolledCyles[consumer]; 22956954a53Spawelszczerbuk for (Value operand : getNestedOperands(consumer)) { 230c933bd81SThomas Raoux auto [producer, distance] = getDefiningOpAndDistance(operand); 231c933bd81SThomas Raoux if (!producer) 232c933bd81SThomas Raoux continue; 233c933bd81SThomas Raoux auto it = unrolledCyles.find(producer); 234c933bd81SThomas Raoux // Skip producer coming from outside the loop. 235c933bd81SThomas Raoux if (it == unrolledCyles.end()) 236c933bd81SThomas Raoux continue; 237c933bd81SThomas Raoux int64_t producerCycle = it->second; 238c933bd81SThomas Raoux if (consumerCycle < producerCycle - numCylesPerIter * distance) { 239c933bd81SThomas Raoux consumer->emitError("operation scheduled before its operands"); 240c933bd81SThomas Raoux return false; 241c933bd81SThomas Raoux } 242c933bd81SThomas Raoux } 243c933bd81SThomas Raoux } 244c933bd81SThomas Raoux return true; 245c933bd81SThomas Raoux } 246c933bd81SThomas Raoux 247f5fe92f6SChristopher Bate /// Clone `op` and call `callback` on the cloned op's oeprands as well as any 248f5fe92f6SChristopher Bate /// operands of nested ops that: 249f5fe92f6SChristopher Bate /// 1) aren't defined within the new op or 250f5fe92f6SChristopher Bate /// 2) are block arguments. 251f5fe92f6SChristopher Bate static Operation * 252f5fe92f6SChristopher Bate cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, 253f5fe92f6SChristopher Bate function_ref<void(OpOperand *newOperand)> callback) { 254f5fe92f6SChristopher Bate Operation *clone = rewriter.clone(*op); 25556954a53Spawelszczerbuk clone->walk<WalkOrder::PreOrder>([&](Operation *nested) { 25656954a53Spawelszczerbuk // 'clone' itself will be visited first. 257f5fe92f6SChristopher Bate for (OpOperand &operand : nested->getOpOperands()) { 258f5fe92f6SChristopher Bate Operation *def = operand.get().getDefiningOp(); 2595550c821STres Popp if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get())) 260f5fe92f6SChristopher Bate callback(&operand); 261f5fe92f6SChristopher Bate } 262f5fe92f6SChristopher Bate }); 263f5fe92f6SChristopher Bate return clone; 264f5fe92f6SChristopher Bate } 265f5fe92f6SChristopher Bate 26618926666SSJW LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { 2677c900811Spawelszczerbuk // Initialize the iteration argument to the loop initial values. 2683cd2a0bcSMatthias Springer for (auto [arg, operand] : 2693cd2a0bcSMatthias Springer llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { 27045cb4140Sthomasraoux setValueMapping(arg, operand.get(), 0); 27145cb4140Sthomasraoux } 27245cb4140Sthomasraoux auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 273ef112833SThomas Raoux Location loc = forOp.getLoc(); 274ef112833SThomas Raoux SmallVector<Value> predicates(maxStage); 275f6f88e66Sthomasraoux for (int64_t i = 0; i < maxStage; i++) { 276ef112833SThomas Raoux if (dynamicLoop) { 277ef112833SThomas Raoux Type t = ub.getType(); 278ef112833SThomas Raoux // pred = ub > lb + (i * step) 279ef112833SThomas Raoux Value iv = rewriter.create<arith::AddIOp>( 280ef112833SThomas Raoux loc, lb, 281ef112833SThomas Raoux rewriter.create<arith::MulIOp>( 282ef112833SThomas Raoux loc, step, 283ef112833SThomas Raoux rewriter.create<arith::ConstantOp>( 284ef112833SThomas Raoux loc, rewriter.getIntegerAttr(t, i)))); 285ef112833SThomas Raoux predicates[i] = rewriter.create<arith::CmpIOp>( 286ef112833SThomas Raoux loc, arith::CmpIPredicate::slt, iv, ub); 287ef112833SThomas Raoux } 288ef112833SThomas Raoux 289f6f88e66Sthomasraoux // special handling for induction variable as the increment is implicit. 290ef112833SThomas Raoux // iv = lb + i * step 291ef112833SThomas Raoux Type t = lb.getType(); 292ef112833SThomas Raoux Value iv = rewriter.create<arith::AddIOp>( 293ef112833SThomas Raoux loc, lb, 294ef112833SThomas Raoux rewriter.create<arith::MulIOp>( 295ef112833SThomas Raoux loc, step, 296ef112833SThomas Raoux rewriter.create<arith::ConstantOp>(loc, 297ef112833SThomas Raoux rewriter.getIntegerAttr(t, i)))); 298f6f88e66Sthomasraoux setValueMapping(forOp.getInductionVar(), iv, i); 299f6f88e66Sthomasraoux for (Operation *op : opOrder) { 300f6f88e66Sthomasraoux if (stages[op] > i) 301f6f88e66Sthomasraoux continue; 302f5fe92f6SChristopher Bate Operation *newOp = 303f5fe92f6SChristopher Bate cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { 304f5fe92f6SChristopher Bate auto it = valueMapping.find(newOperand->get()); 305f5fe92f6SChristopher Bate if (it != valueMapping.end()) { 306f5fe92f6SChristopher Bate Value replacement = it->second[i - stages[op]]; 307f5fe92f6SChristopher Bate newOperand->set(replacement); 308f6f88e66Sthomasraoux } 309f5fe92f6SChristopher Bate }); 310ef112833SThomas Raoux int predicateIdx = i - stages[op]; 311ef112833SThomas Raoux if (predicates[predicateIdx]) { 312ebf05993SSJW OpBuilder::InsertionGuard insertGuard(rewriter); 313ef112833SThomas Raoux newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); 31418926666SSJW if (newOp == nullptr) 31518926666SSJW return failure(); 316ef112833SThomas Raoux } 3170736bbd7SThomas Raoux if (annotateFn) 3180736bbd7SThomas Raoux annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); 319f6f88e66Sthomasraoux for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { 3207c900811Spawelszczerbuk Value source = newOp->getResult(destId); 32145cb4140Sthomasraoux // If the value is a loop carried dependency update the loop argument 32245cb4140Sthomasraoux for (OpOperand &operand : yield->getOpOperands()) { 32345cb4140Sthomasraoux if (operand.get() != op->getResult(destId)) 32445cb4140Sthomasraoux continue; 3257c900811Spawelszczerbuk if (predicates[predicateIdx] && 3267c900811Spawelszczerbuk !forOp.getResult(operand.getOperandNumber()).use_empty()) { 3277c900811Spawelszczerbuk // If the value is used outside the loop, we need to make sure we 3287c900811Spawelszczerbuk // return the correct version of it. 3297c900811Spawelszczerbuk Value prevValue = valueMapping 3307c900811Spawelszczerbuk [forOp.getRegionIterArgs()[operand.getOperandNumber()]] 3317c900811Spawelszczerbuk [i - stages[op]]; 3327c900811Spawelszczerbuk source = rewriter.create<arith::SelectOp>( 3337c900811Spawelszczerbuk loc, predicates[predicateIdx], source, prevValue); 33445cb4140Sthomasraoux } 3357c900811Spawelszczerbuk setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], 3367c900811Spawelszczerbuk source, i - stages[op] + 1); 3377c900811Spawelszczerbuk } 3387c900811Spawelszczerbuk setValueMapping(op->getResult(destId), newOp->getResult(destId), 3397c900811Spawelszczerbuk i - stages[op]); 340f6f88e66Sthomasraoux } 341f6f88e66Sthomasraoux } 342f6f88e66Sthomasraoux } 34318926666SSJW return success(); 344f6f88e66Sthomasraoux } 345f6f88e66Sthomasraoux 346f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 347f6f88e66Sthomasraoux LoopPipelinerInternal::analyzeCrossStageValues() { 348f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues; 349f6f88e66Sthomasraoux for (Operation *op : opOrder) { 350f6f88e66Sthomasraoux unsigned stage = stages[op]; 351f5fe92f6SChristopher Bate 352f5fe92f6SChristopher Bate auto analyzeOperand = [&](OpOperand &operand) { 35319e068b0SThomas Raoux auto [def, distance] = getDefiningOpAndDistance(operand.get()); 354f6f88e66Sthomasraoux if (!def) 355f5fe92f6SChristopher Bate return; 356f6f88e66Sthomasraoux auto defStage = stages.find(def); 35719e068b0SThomas Raoux if (defStage == stages.end() || defStage->second == stage || 35819e068b0SThomas Raoux defStage->second == stage + distance) 359f5fe92f6SChristopher Bate return; 360f6f88e66Sthomasraoux assert(stage > defStage->second); 361f6f88e66Sthomasraoux LiverangeInfo &info = crossStageValues[operand.get()]; 362f6f88e66Sthomasraoux info.defStage = defStage->second; 363f6f88e66Sthomasraoux info.lastUseStage = std::max(info.lastUseStage, stage); 364f5fe92f6SChristopher Bate }; 365f5fe92f6SChristopher Bate 366f5fe92f6SChristopher Bate for (OpOperand &operand : op->getOpOperands()) 367f5fe92f6SChristopher Bate analyzeOperand(operand); 368f5fe92f6SChristopher Bate visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { 369f5fe92f6SChristopher Bate analyzeOperand(*operand); 370f5fe92f6SChristopher Bate }); 371f6f88e66Sthomasraoux } 372f6f88e66Sthomasraoux return crossStageValues; 373f6f88e66Sthomasraoux } 374f6f88e66Sthomasraoux 37519e068b0SThomas Raoux std::pair<Operation *, int64_t> 37619e068b0SThomas Raoux LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { 37719e068b0SThomas Raoux int64_t distance = 0; 37819e068b0SThomas Raoux if (auto arg = dyn_cast<BlockArgument>(value)) { 37919e068b0SThomas Raoux if (arg.getOwner() != forOp.getBody()) 38019e068b0SThomas Raoux return {nullptr, 0}; 38119e068b0SThomas Raoux // Ignore induction variable. 38219e068b0SThomas Raoux if (arg.getArgNumber() == 0) 38319e068b0SThomas Raoux return {nullptr, 0}; 38419e068b0SThomas Raoux distance++; 38519e068b0SThomas Raoux value = 38619e068b0SThomas Raoux forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); 38719e068b0SThomas Raoux } 38819e068b0SThomas Raoux Operation *def = value.getDefiningOp(); 38919e068b0SThomas Raoux if (!def) 39019e068b0SThomas Raoux return {nullptr, 0}; 39119e068b0SThomas Raoux return {def, distance}; 39219e068b0SThomas Raoux } 39319e068b0SThomas Raoux 394f6f88e66Sthomasraoux scf::ForOp LoopPipelinerInternal::createKernelLoop( 395f6f88e66Sthomasraoux const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 396f6f88e66Sthomasraoux &crossStageValues, 3971cff4cbdSNicolas Vasilache RewriterBase &rewriter, 398f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) { 399f6f88e66Sthomasraoux // Creates the list of initial values associated to values used across 400f6f88e66Sthomasraoux // stages. The initial values come from the prologue created above. 401f6f88e66Sthomasraoux // Keep track of the kernel argument associated to each version of the 402f6f88e66Sthomasraoux // values passed to the kernel. 40345cb4140Sthomasraoux llvm::SmallVector<Value> newLoopArg; 40445cb4140Sthomasraoux // For existing loop argument initialize them with the right version from the 40545cb4140Sthomasraoux // prologue. 406e4853be2SMehdi Amini for (const auto &retVal : 40745cb4140Sthomasraoux llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 40845cb4140Sthomasraoux Operation *def = retVal.value().getDefiningOp(); 409e66f97e8SKeren Zhou assert(def && "Only support loop carried dependencies of distance of 1 or " 410e66f97e8SKeren Zhou "outside the loop"); 411e66f97e8SKeren Zhou auto defStage = stages.find(def); 412e66f97e8SKeren Zhou if (defStage != stages.end()) { 413e66f97e8SKeren Zhou Value valueVersion = 414e66f97e8SKeren Zhou valueMapping[forOp.getRegionIterArgs()[retVal.index()]] 415e66f97e8SKeren Zhou [maxStage - defStage->second]; 41645cb4140Sthomasraoux assert(valueVersion); 41745cb4140Sthomasraoux newLoopArg.push_back(valueVersion); 418e66f97e8SKeren Zhou } else 419e66f97e8SKeren Zhou newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); 42045cb4140Sthomasraoux } 421f6f88e66Sthomasraoux for (auto escape : crossStageValues) { 422f6f88e66Sthomasraoux LiverangeInfo &info = escape.second; 423f6f88e66Sthomasraoux Value value = escape.first; 424f6f88e66Sthomasraoux for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; 425f6f88e66Sthomasraoux stageIdx++) { 426f6f88e66Sthomasraoux Value valueVersion = 427f6f88e66Sthomasraoux valueMapping[value][maxStage - info.lastUseStage + stageIdx]; 428f6f88e66Sthomasraoux assert(valueVersion); 429f6f88e66Sthomasraoux newLoopArg.push_back(valueVersion); 430f6f88e66Sthomasraoux loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - 431f6f88e66Sthomasraoux stageIdx)] = newLoopArg.size() - 1; 432f6f88e66Sthomasraoux } 433f6f88e66Sthomasraoux } 434f6f88e66Sthomasraoux 435205c08b5SThomas Raoux // Create the new kernel loop. When we peel the epilgue we need to peel 436205c08b5SThomas Raoux // `numStages - 1` iterations. Then we adjust the upper bound to remove those 437205c08b5SThomas Raoux // iterations. 438205c08b5SThomas Raoux Value newUb = forOp.getUpperBound(); 439ef112833SThomas Raoux if (peelEpilogue) { 440ef112833SThomas Raoux Type t = ub.getType(); 441ef112833SThomas Raoux Location loc = forOp.getLoc(); 442ef112833SThomas Raoux // newUb = ub - maxStage * step 443ef112833SThomas Raoux Value maxStageValue = rewriter.create<arith::ConstantOp>( 444ef112833SThomas Raoux loc, rewriter.getIntegerAttr(t, maxStage)); 445ef112833SThomas Raoux Value maxStageByStep = 446ef112833SThomas Raoux rewriter.create<arith::MulIOp>(loc, step, maxStageValue); 447ef112833SThomas Raoux newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep); 448ef112833SThomas Raoux } 449c0342a2dSJacques Pienaar auto newForOp = 450c0342a2dSJacques Pienaar rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb, 451c0342a2dSJacques Pienaar forOp.getStep(), newLoopArg); 452f5fe92f6SChristopher Bate // When there are no iter args, the loop body terminator will be created. 453f5fe92f6SChristopher Bate // Since we always create it below, remove the terminator if it was created. 454f5fe92f6SChristopher Bate if (!newForOp.getBody()->empty()) 455f5fe92f6SChristopher Bate rewriter.eraseOp(newForOp.getBody()->getTerminator()); 456f6f88e66Sthomasraoux return newForOp; 457f6f88e66Sthomasraoux } 458f6f88e66Sthomasraoux 459371366ceSAlex Zinenko LogicalResult LoopPipelinerInternal::createKernel( 460f6f88e66Sthomasraoux scf::ForOp newForOp, 461f6f88e66Sthomasraoux const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 462f6f88e66Sthomasraoux &crossStageValues, 463f6f88e66Sthomasraoux const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 4641cff4cbdSNicolas Vasilache RewriterBase &rewriter) { 465f6f88e66Sthomasraoux valueMapping.clear(); 466f6f88e66Sthomasraoux 467f6f88e66Sthomasraoux // Create the kernel, we clone instruction based on the order given by 468f6f88e66Sthomasraoux // user and remap operands coming from a previous stages. 469f6f88e66Sthomasraoux rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); 4704d67b278SJeff Niu IRMapping mapping; 471f6f88e66Sthomasraoux mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); 472e4853be2SMehdi Amini for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { 47345cb4140Sthomasraoux mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); 47445cb4140Sthomasraoux } 475205c08b5SThomas Raoux SmallVector<Value> predicates(maxStage + 1, nullptr); 476205c08b5SThomas Raoux if (!peelEpilogue) { 477205c08b5SThomas Raoux // Create a predicate for each stage except the last stage. 478ef112833SThomas Raoux Location loc = newForOp.getLoc(); 479ef112833SThomas Raoux Type t = ub.getType(); 480205c08b5SThomas Raoux for (unsigned i = 0; i < maxStage; i++) { 481ef112833SThomas Raoux // c = ub - (maxStage - i) * step 482ef112833SThomas Raoux Value c = rewriter.create<arith::SubIOp>( 483ef112833SThomas Raoux loc, ub, 484ef112833SThomas Raoux rewriter.create<arith::MulIOp>( 485ef112833SThomas Raoux loc, step, 486ef112833SThomas Raoux rewriter.create<arith::ConstantOp>( 487ef112833SThomas Raoux loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); 488ef112833SThomas Raoux 489205c08b5SThomas Raoux Value pred = rewriter.create<arith::CmpIOp>( 490205c08b5SThomas Raoux newForOp.getLoc(), arith::CmpIPredicate::slt, 491205c08b5SThomas Raoux newForOp.getInductionVar(), c); 492205c08b5SThomas Raoux predicates[i] = pred; 493205c08b5SThomas Raoux } 494205c08b5SThomas Raoux } 495f6f88e66Sthomasraoux for (Operation *op : opOrder) { 496f6f88e66Sthomasraoux int64_t useStage = stages[op]; 497f6f88e66Sthomasraoux auto *newOp = rewriter.clone(*op, mapping); 498117db47dSThomas Raoux SmallVector<OpOperand *> operands; 499117db47dSThomas Raoux // Collect all the operands for the cloned op and its nested ops. 500117db47dSThomas Raoux op->walk([&operands](Operation *nestedOp) { 501117db47dSThomas Raoux for (OpOperand &operand : nestedOp->getOpOperands()) { 502117db47dSThomas Raoux operands.push_back(&operand); 503117db47dSThomas Raoux } 504117db47dSThomas Raoux }); 505117db47dSThomas Raoux for (OpOperand *operand : operands) { 506117db47dSThomas Raoux Operation *nestedNewOp = mapping.lookup(operand->getOwner()); 507117db47dSThomas Raoux // Special case for the induction variable uses. We replace it with a 508117db47dSThomas Raoux // version incremented based on the stage where it is used. 509117db47dSThomas Raoux if (operand->get() == forOp.getInductionVar()) { 510117db47dSThomas Raoux rewriter.setInsertionPoint(newOp); 511ef112833SThomas Raoux 512ef112833SThomas Raoux // offset = (maxStage - stages[op]) * step 513ef112833SThomas Raoux Type t = step.getType(); 514ef112833SThomas Raoux Value offset = rewriter.create<arith::MulIOp>( 515ef112833SThomas Raoux forOp.getLoc(), step, 516ef112833SThomas Raoux rewriter.create<arith::ConstantOp>( 517ef112833SThomas Raoux forOp.getLoc(), 518ef112833SThomas Raoux rewriter.getIntegerAttr(t, maxStage - stages[op]))); 519117db47dSThomas Raoux Value iv = rewriter.create<arith::AddIOp>( 520117db47dSThomas Raoux forOp.getLoc(), newForOp.getInductionVar(), offset); 521117db47dSThomas Raoux nestedNewOp->setOperand(operand->getOperandNumber(), iv); 522117db47dSThomas Raoux rewriter.setInsertionPointAfter(newOp); 523117db47dSThomas Raoux continue; 524117db47dSThomas Raoux } 52519e068b0SThomas Raoux Value source = operand->get(); 52619e068b0SThomas Raoux auto arg = dyn_cast<BlockArgument>(source); 527117db47dSThomas Raoux if (arg && arg.getOwner() == forOp.getBody()) { 528117db47dSThomas Raoux Value ret = forOp.getBody()->getTerminator()->getOperand( 529117db47dSThomas Raoux arg.getArgNumber() - 1); 530117db47dSThomas Raoux Operation *dep = ret.getDefiningOp(); 531117db47dSThomas Raoux if (!dep) 532117db47dSThomas Raoux continue; 533117db47dSThomas Raoux auto stageDep = stages.find(dep); 534117db47dSThomas Raoux if (stageDep == stages.end() || stageDep->second == useStage) 535117db47dSThomas Raoux continue; 53619e068b0SThomas Raoux // If the value is a loop carried value coming from stage N + 1 remap, 53719e068b0SThomas Raoux // it will become a direct use. 53819e068b0SThomas Raoux if (stageDep->second == useStage + 1) { 539117db47dSThomas Raoux nestedNewOp->setOperand(operand->getOperandNumber(), 540117db47dSThomas Raoux mapping.lookupOrDefault(ret)); 541117db47dSThomas Raoux continue; 542117db47dSThomas Raoux } 54319e068b0SThomas Raoux source = ret; 54419e068b0SThomas Raoux } 545117db47dSThomas Raoux // For operands defined in a previous stage we need to remap it to use 546117db47dSThomas Raoux // the correct region argument. We look for the right version of the 547117db47dSThomas Raoux // Value based on the stage where it is used. 54819e068b0SThomas Raoux Operation *def = source.getDefiningOp(); 549117db47dSThomas Raoux if (!def) 550117db47dSThomas Raoux continue; 551117db47dSThomas Raoux auto stageDef = stages.find(def); 552117db47dSThomas Raoux if (stageDef == stages.end() || stageDef->second == useStage) 553117db47dSThomas Raoux continue; 554117db47dSThomas Raoux auto remap = loopArgMap.find( 555117db47dSThomas Raoux std::make_pair(operand->get(), useStage - stageDef->second)); 556117db47dSThomas Raoux assert(remap != loopArgMap.end()); 557117db47dSThomas Raoux nestedNewOp->setOperand(operand->getOperandNumber(), 558117db47dSThomas Raoux newForOp.getRegionIterArgs()[remap->second]); 559117db47dSThomas Raoux } 560f5fe92f6SChristopher Bate 561205c08b5SThomas Raoux if (predicates[useStage]) { 562ebf05993SSJW OpBuilder::InsertionGuard insertGuard(rewriter); 5631cff4cbdSNicolas Vasilache newOp = predicateFn(rewriter, newOp, predicates[useStage]); 564371366ceSAlex Zinenko if (!newOp) 565371366ceSAlex Zinenko return failure(); 566205c08b5SThomas Raoux // Remap the results to the new predicated one. 567205c08b5SThomas Raoux for (auto values : llvm::zip(op->getResults(), newOp->getResults())) 568205c08b5SThomas Raoux mapping.map(std::get<0>(values), std::get<1>(values)); 569205c08b5SThomas Raoux } 5700736bbd7SThomas Raoux if (annotateFn) 5710736bbd7SThomas Raoux annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); 572f6f88e66Sthomasraoux } 573f6f88e66Sthomasraoux 574f6f88e66Sthomasraoux // Collect the Values that need to be returned by the forOp. For each 575f6f88e66Sthomasraoux // value we need to have `LastUseStage - DefStage` number of versions 576f6f88e66Sthomasraoux // returned. 577f6f88e66Sthomasraoux // We create a mapping between original values and the associated loop 578f6f88e66Sthomasraoux // returned values that will be needed by the epilogue. 579f6f88e66Sthomasraoux llvm::SmallVector<Value> yieldOperands; 58019e068b0SThomas Raoux for (OpOperand &yieldOperand : 58119e068b0SThomas Raoux forOp.getBody()->getTerminator()->getOpOperands()) { 58219e068b0SThomas Raoux Value source = mapping.lookupOrDefault(yieldOperand.get()); 58319e068b0SThomas Raoux // When we don't peel the epilogue and the yield value is used outside the 58419e068b0SThomas Raoux // loop we need to make sure we return the version from numStages - 58519e068b0SThomas Raoux // defStage. 58619e068b0SThomas Raoux if (!peelEpilogue && 58719e068b0SThomas Raoux !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { 58819e068b0SThomas Raoux Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; 58919e068b0SThomas Raoux if (def) { 59019e068b0SThomas Raoux auto defStage = stages.find(def); 59119e068b0SThomas Raoux if (defStage != stages.end() && defStage->second < maxStage) { 59219e068b0SThomas Raoux Value pred = predicates[defStage->second]; 59319e068b0SThomas Raoux source = rewriter.create<arith::SelectOp>( 59419e068b0SThomas Raoux pred.getLoc(), pred, source, 59519e068b0SThomas Raoux newForOp.getBody() 59619e068b0SThomas Raoux ->getArguments()[yieldOperand.getOperandNumber() + 1]); 59745cb4140Sthomasraoux } 59819e068b0SThomas Raoux } 59919e068b0SThomas Raoux } 60019e068b0SThomas Raoux yieldOperands.push_back(source); 60119e068b0SThomas Raoux } 60219e068b0SThomas Raoux 603f6f88e66Sthomasraoux for (auto &it : crossStageValues) { 604f6f88e66Sthomasraoux int64_t version = maxStage - it.second.lastUseStage + 1; 605f6f88e66Sthomasraoux unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; 606371366ceSAlex Zinenko // add the original version to yield ops. 607371366ceSAlex Zinenko // If there is a live range spanning across more than 2 stages we need to 608371366ceSAlex Zinenko // add extra arg. 609f6f88e66Sthomasraoux for (unsigned i = 1; i < numVersionReturned; i++) { 610f6f88e66Sthomasraoux setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 611f6f88e66Sthomasraoux version++); 612f6f88e66Sthomasraoux yieldOperands.push_back( 613f6f88e66Sthomasraoux newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + 614f6f88e66Sthomasraoux newForOp.getNumInductionVars()]); 615f6f88e66Sthomasraoux } 616f6f88e66Sthomasraoux setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 617f6f88e66Sthomasraoux version++); 618f6f88e66Sthomasraoux yieldOperands.push_back(mapping.lookupOrDefault(it.first)); 619f6f88e66Sthomasraoux } 62045cb4140Sthomasraoux // Map the yield operand to the forOp returned value. 621e4853be2SMehdi Amini for (const auto &retVal : 62245cb4140Sthomasraoux llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 62345cb4140Sthomasraoux Operation *def = retVal.value().getDefiningOp(); 624e66f97e8SKeren Zhou assert(def && "Only support loop carried dependencies of distance of 1 or " 625e66f97e8SKeren Zhou "defined outside the loop"); 626e66f97e8SKeren Zhou auto defStage = stages.find(def); 627e66f97e8SKeren Zhou if (defStage == stages.end()) { 628e66f97e8SKeren Zhou for (unsigned int stage = 1; stage <= maxStage; stage++) 629e66f97e8SKeren Zhou setValueMapping(forOp.getRegionIterArgs()[retVal.index()], 630e66f97e8SKeren Zhou retVal.value(), stage); 631e66f97e8SKeren Zhou } else if (defStage->second > 0) { 63245cb4140Sthomasraoux setValueMapping(forOp.getRegionIterArgs()[retVal.index()], 63345cb4140Sthomasraoux newForOp->getResult(retVal.index()), 634e66f97e8SKeren Zhou maxStage - defStage->second + 1); 63545cb4140Sthomasraoux } 63619e068b0SThomas Raoux } 637f6f88e66Sthomasraoux rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); 638371366ceSAlex Zinenko return success(); 639f6f88e66Sthomasraoux } 640f6f88e66Sthomasraoux 641ebf05993SSJW LogicalResult 642ebf05993SSJW LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, 643ebf05993SSJW llvm::SmallVector<Value> &returnValues) { 644ebf05993SSJW Location loc = forOp.getLoc(); 645*8da5aa16SSJW Type t = lb.getType(); 646*8da5aa16SSJW 647f6f88e66Sthomasraoux // Emit different versions of the induction variable. They will be 648f6f88e66Sthomasraoux // removed by dead code if not used. 649ebf05993SSJW 650*8da5aa16SSJW auto createConst = [&](int v) { 651*8da5aa16SSJW return rewriter.create<arith::ConstantOp>(loc, 652*8da5aa16SSJW rewriter.getIntegerAttr(t, v)); 653*8da5aa16SSJW }; 654*8da5aa16SSJW 655*8da5aa16SSJW // total_iterations = cdiv(range_diff, step); 656*8da5aa16SSJW // - range_diff = ub - lb 657*8da5aa16SSJW // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step 658*8da5aa16SSJW Value zero = createConst(0); 659*8da5aa16SSJW Value one = createConst(1); 6607645d9c7SSJW Value stepLessZero = rewriter.create<arith::CmpIOp>( 6617645d9c7SSJW loc, arith::CmpIPredicate::slt, step, zero); 6627645d9c7SSJW Value stepDecr = 663*8da5aa16SSJW rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1)); 6647645d9c7SSJW 6657645d9c7SSJW Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb); 6667645d9c7SSJW Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step); 6677645d9c7SSJW Value rangeDecr = 6687645d9c7SSJW rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr); 6697645d9c7SSJW Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step); 670fa089b01SSJW 671*8da5aa16SSJW // If total_iters < max_stage, start the epilogue at zero to match the 672*8da5aa16SSJW // ramp-up in the prologue. 673*8da5aa16SSJW // start_iter = max(0, total_iters - max_stage) 674*8da5aa16SSJW Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations, 675*8da5aa16SSJW createConst(maxStage)); 676*8da5aa16SSJW iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI); 677*8da5aa16SSJW 678*8da5aa16SSJW // Capture predicates for dynamic loops. 679ebf05993SSJW SmallVector<Value> predicates(maxStage + 1); 680*8da5aa16SSJW 681*8da5aa16SSJW for (int64_t i = 1; i <= maxStage; i++) { 682ebf05993SSJW // newLastIter = lb + step * iterI 683ef112833SThomas Raoux Value newlastIter = rewriter.create<arith::AddIOp>( 684ebf05993SSJW loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI)); 685ebf05993SSJW 686*8da5aa16SSJW setValueMapping(forOp.getInductionVar(), newlastIter, i); 687*8da5aa16SSJW 688*8da5aa16SSJW // increment to next iterI 689*8da5aa16SSJW iterI = rewriter.create<arith::AddIOp>(loc, iterI, one); 690ebf05993SSJW 691ebf05993SSJW if (dynamicLoop) { 692*8da5aa16SSJW // Disable stages when `i` is greater than total_iters. 693*8da5aa16SSJW // pred = total_iters >= i 694*8da5aa16SSJW predicates[i] = rewriter.create<arith::CmpIOp>( 695*8da5aa16SSJW loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); 696f6f88e66Sthomasraoux } 697ebf05993SSJW } 698ebf05993SSJW 699f5fe92f6SChristopher Bate // Emit `maxStage - 1` epilogue part that includes operations from stages 700f6f88e66Sthomasraoux // [i; maxStage]. 701f6f88e66Sthomasraoux for (int64_t i = 1; i <= maxStage; i++) { 702ebf05993SSJW SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size()); 703f6f88e66Sthomasraoux for (Operation *op : opOrder) { 704f6f88e66Sthomasraoux if (stages[op] < i) 705f6f88e66Sthomasraoux continue; 706ebf05993SSJW unsigned currentVersion = maxStage - stages[op] + i; 707ebf05993SSJW unsigned nextVersion = currentVersion + 1; 708f5fe92f6SChristopher Bate Operation *newOp = 709f5fe92f6SChristopher Bate cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { 710f5fe92f6SChristopher Bate auto it = valueMapping.find(newOperand->get()); 711f6f88e66Sthomasraoux if (it != valueMapping.end()) { 712ebf05993SSJW Value replacement = it->second[currentVersion]; 713f5fe92f6SChristopher Bate newOperand->set(replacement); 714f6f88e66Sthomasraoux } 715f5fe92f6SChristopher Bate }); 716ebf05993SSJW if (dynamicLoop) { 717ebf05993SSJW OpBuilder::InsertionGuard insertGuard(rewriter); 718ebf05993SSJW newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); 719ebf05993SSJW if (!newOp) 720ebf05993SSJW return failure(); 721ebf05993SSJW } 7220736bbd7SThomas Raoux if (annotateFn) 7230736bbd7SThomas Raoux annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); 724ebf05993SSJW 725ebf05993SSJW for (auto [opRes, newRes] : 726ebf05993SSJW llvm::zip(op->getResults(), newOp->getResults())) { 727ebf05993SSJW setValueMapping(opRes, newRes, currentVersion); 72845cb4140Sthomasraoux // If the value is a loop carried dependency update the loop argument 72945cb4140Sthomasraoux // mapping and keep track of the last version to replace the original 73045cb4140Sthomasraoux // forOp uses. 73145cb4140Sthomasraoux for (OpOperand &operand : 73245cb4140Sthomasraoux forOp.getBody()->getTerminator()->getOpOperands()) { 733ebf05993SSJW if (operand.get() != opRes) 73445cb4140Sthomasraoux continue; 73545cb4140Sthomasraoux // If the version is greater than maxStage it means it maps to the 73645cb4140Sthomasraoux // original forOp returned value. 737ebf05993SSJW unsigned ri = operand.getOperandNumber(); 738ebf05993SSJW returnValues[ri] = newRes; 739ebf05993SSJW Value mapVal = forOp.getRegionIterArgs()[ri]; 740ebf05993SSJW returnMap[ri] = std::make_pair(mapVal, currentVersion); 741ebf05993SSJW if (nextVersion <= maxStage) 742ebf05993SSJW setValueMapping(mapVal, newRes, nextVersion); 74345cb4140Sthomasraoux } 744ebf05993SSJW } 745ebf05993SSJW } 746ebf05993SSJW if (dynamicLoop) { 747ebf05993SSJW // Select return values from this stage (live outs) based on predication. 748ebf05993SSJW // If the stage is valid select the peeled value, else use previous stage 749ebf05993SSJW // value. 750ebf05993SSJW for (auto pair : llvm::enumerate(returnValues)) { 751ebf05993SSJW unsigned ri = pair.index(); 752ebf05993SSJW auto [mapVal, currentVersion] = returnMap[ri]; 753ebf05993SSJW if (mapVal) { 754ebf05993SSJW unsigned nextVersion = currentVersion + 1; 755ebf05993SSJW Value pred = predicates[currentVersion]; 756ebf05993SSJW Value prevValue = valueMapping[mapVal][currentVersion]; 757ebf05993SSJW auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(), 758ebf05993SSJW prevValue); 759ebf05993SSJW returnValues[ri] = selOp; 760ebf05993SSJW if (nextVersion <= maxStage) 761ebf05993SSJW setValueMapping(mapVal, selOp, nextVersion); 762f6f88e66Sthomasraoux } 763f6f88e66Sthomasraoux } 764f6f88e66Sthomasraoux } 765f6f88e66Sthomasraoux } 766ebf05993SSJW return success(); 76745cb4140Sthomasraoux } 768f6f88e66Sthomasraoux 769f6f88e66Sthomasraoux void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { 770f6f88e66Sthomasraoux auto it = valueMapping.find(key); 771f6f88e66Sthomasraoux // If the value is not in the map yet add a vector big enough to store all 772f6f88e66Sthomasraoux // versions. 773f6f88e66Sthomasraoux if (it == valueMapping.end()) 774f6f88e66Sthomasraoux it = 775f6f88e66Sthomasraoux valueMapping 776f6f88e66Sthomasraoux .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1))) 777f6f88e66Sthomasraoux .first; 778f6f88e66Sthomasraoux it->second[idx] = el; 779f6f88e66Sthomasraoux } 780f6f88e66Sthomasraoux 7815f0d4f20SAlex Zinenko } // namespace 7825f0d4f20SAlex Zinenko 7831cff4cbdSNicolas Vasilache FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, 784371366ceSAlex Zinenko const PipeliningOption &options, 785371366ceSAlex Zinenko bool *modifiedIR) { 786371366ceSAlex Zinenko if (modifiedIR) 787371366ceSAlex Zinenko *modifiedIR = false; 788f6f88e66Sthomasraoux LoopPipelinerInternal pipeliner; 789f6f88e66Sthomasraoux if (!pipeliner.initializeLoopInfo(forOp, options)) 790f6f88e66Sthomasraoux return failure(); 791f6f88e66Sthomasraoux 792371366ceSAlex Zinenko if (modifiedIR) 793371366ceSAlex Zinenko *modifiedIR = true; 794371366ceSAlex Zinenko 795f6f88e66Sthomasraoux // 1. Emit prologue. 79618926666SSJW if (failed(pipeliner.emitPrologue(rewriter))) 79718926666SSJW return failure(); 798f6f88e66Sthomasraoux 799f6f88e66Sthomasraoux // 2. Track values used across stages. When a value cross stages it will 800f6f88e66Sthomasraoux // need to be passed as loop iteration arguments. 801f6f88e66Sthomasraoux // We first collect the values that are used in a different stage than where 802f6f88e66Sthomasraoux // they are defined. 803f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 804f6f88e66Sthomasraoux crossStageValues = pipeliner.analyzeCrossStageValues(); 805f6f88e66Sthomasraoux 806f6f88e66Sthomasraoux // Mapping between original loop values used cross stage and the block 807f6f88e66Sthomasraoux // arguments associated after pipelining. A Value may map to several 808f6f88e66Sthomasraoux // arguments if its liverange spans across more than 2 stages. 809f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap; 810f6f88e66Sthomasraoux // 3. Create the new kernel loop and return the block arguments mapping. 811f6f88e66Sthomasraoux ForOp newForOp = 812f6f88e66Sthomasraoux pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); 813f6f88e66Sthomasraoux // Create the kernel block, order ops based on user choice and remap 814f6f88e66Sthomasraoux // operands. 815371366ceSAlex Zinenko if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, 816371366ceSAlex Zinenko rewriter))) 817371366ceSAlex Zinenko return failure(); 818f6f88e66Sthomasraoux 819205c08b5SThomas Raoux llvm::SmallVector<Value> returnValues = 820205c08b5SThomas Raoux newForOp.getResults().take_front(forOp->getNumResults()); 821205c08b5SThomas Raoux if (options.peelEpilogue) { 822f6f88e66Sthomasraoux // 4. Emit the epilogue after the new forOp. 823f6f88e66Sthomasraoux rewriter.setInsertionPointAfter(newForOp); 824ebf05993SSJW if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) 825ebf05993SSJW return failure(); 826205c08b5SThomas Raoux } 827f6f88e66Sthomasraoux // 5. Erase the original loop and replace the uses with the epilogue output. 828f6f88e66Sthomasraoux if (forOp->getNumResults() > 0) 82945cb4140Sthomasraoux rewriter.replaceOp(forOp, returnValues); 830f6f88e66Sthomasraoux else 831f6f88e66Sthomasraoux rewriter.eraseOp(forOp); 832f6f88e66Sthomasraoux 8335f0d4f20SAlex Zinenko return newForOp; 834f6f88e66Sthomasraoux } 835f6f88e66Sthomasraoux 836f6f88e66Sthomasraoux void mlir::scf::populateSCFLoopPipeliningPatterns( 837f6f88e66Sthomasraoux RewritePatternSet &patterns, const PipeliningOption &options) { 8385f0d4f20SAlex Zinenko patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext()); 839f6f88e66Sthomasraoux } 840