xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (revision 8da5aa16f65bc297663573bacd3030f975b9fcde)
1f6f88e66Sthomasraoux //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2f6f88e66Sthomasraoux //
3f6f88e66Sthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f6f88e66Sthomasraoux // See https://llvm.org/LICENSE.txt for license information.
5f6f88e66Sthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f6f88e66Sthomasraoux //
7f6f88e66Sthomasraoux //===----------------------------------------------------------------------===//
8f6f88e66Sthomasraoux //
9f6f88e66Sthomasraoux // This file implements loop software pipelining
10f6f88e66Sthomasraoux //
11f6f88e66Sthomasraoux //===----------------------------------------------------------------------===//
12f6f88e66Sthomasraoux 
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
184d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
19f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h"
20f5fe92f6SChristopher Bate #include "mlir/Transforms/RegionUtils.h"
2167d0d7acSMichele Scuttari #include "llvm/ADT/MapVector.h"
22c7592c77SNicolas Vasilache #include "llvm/Support/Debug.h"
230fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h"
24c7592c77SNicolas Vasilache 
25c7592c77SNicolas Vasilache #define DEBUG_TYPE "scf-loop-pipelining"
26c7592c77SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27c7592c77SNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28f6f88e66Sthomasraoux 
29f6f88e66Sthomasraoux using namespace mlir;
30f6f88e66Sthomasraoux using namespace mlir::scf;
31f6f88e66Sthomasraoux 
32f6f88e66Sthomasraoux namespace {
33f6f88e66Sthomasraoux 
/// Helper to keep internal information during pipelining transformation.
struct LoopPipelinerInternal {
  /// Coarse liverange information for ops used across stages.
  struct LiverangeInfo {
    /// Latest stage in which the value is used.
    unsigned lastUseStage = 0;
    /// Stage in which the value is defined.
    unsigned defStage = 0;
  };

protected:
  /// The loop being pipelined.
  ForOp forOp;
  /// Largest stage index found in the schedule.
  unsigned maxStage = 0;
  /// Stage assigned to each scheduled op.
  DenseMap<Operation *, unsigned> stages;
  /// Scheduled ops, in the order given by the user schedule.
  std::vector<Operation *> opOrder;
  /// Upper bound, lower bound and step of `forOp`.
  Value ub;
  Value lb;
  Value step;
  /// When true, prologue parts are guarded by runtime predicates. It is
  /// cleared by initializeLoopInfo when the bounds and step are constants and
  /// the computed trip count exceeds `maxStage`.
  bool dynamicLoop;
  /// Optional user callback used to annotate the ops emitted for each
  /// pipeliner part.
  PipeliningOption::AnnotationlFnType annotateFn = nullptr;
  /// Whether to peel a separate epilogue. When false, the kernel keeps the
  /// original upper bound and predicates its non-final stages instead.
  bool peelEpilogue;
  /// Callback that makes an op conditional on a predicate; a nullptr result
  /// is treated as failure.
  PipeliningOption::PredicateOpFn predicateFn = nullptr;

  // When peeling the kernel we generate several versions of each value for
  // the different stages of the prologue. This map tracks the mapping between
  // the original Values in the loop and the different versions peeled from
  // the loop.
  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;

  /// Record in `valueMapping` that `el` is version `idx` of the original
  /// value `key`.
  void setValueMapping(Value key, Value el, int64_t idx);

  /// Return the defining op of the given value. If the Value is a region
  /// iter_arg of the loop, return the in-loop op defining the value yielded
  /// for it, together with its distance (in iterations) to the Value.
  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

  /// Return true if the schedule is possible and false otherwise. A schedule
  /// is correct if all definitions are scheduled before uses.
  bool verifySchedule();

public:
  /// Initialize the information for the given `op`. Return true if it
  /// satisfies the preconditions to apply pipelining.
  bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
  /// Emits the prologue. This creates `maxStage` parts; part `i` contains
  /// operations from stages [0; i].
  LogicalResult emitPrologue(RewriterBase &rewriter);
  /// Gather liverange information for Values that are used in a different
  /// stage than their definition.
  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
  /// Create the (empty) kernel loop and its initial iter_args; `loopArgMap`
  /// receives the iter_arg index assigned to each cross-stage value version.
  scf::ForOp createKernelLoop(
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      RewriterBase &rewriter,
      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
  /// Emits the pipelined kernel. This clones loop operations following user
  /// order and remaps operands defined in a different stage than their use.
  LogicalResult createKernel(
      scf::ForOp newForOp,
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
      RewriterBase &rewriter);
  /// Emits the epilogue; part `i` contains operations from stages
  /// [i; maxStage]. NOTE(review): the part count was previously documented
  /// as `maxStage - 1`; confirm against the implementation.
  LogicalResult emitEpilogue(RewriterBase &rewriter,
                             llvm::SmallVector<Value> &returnValues);
};
100f6f88e66Sthomasraoux 
101f6f88e66Sthomasraoux bool LoopPipelinerInternal::initializeLoopInfo(
102f6f88e66Sthomasraoux     ForOp op, const PipeliningOption &options) {
103c7592c77SNicolas Vasilache   LDBG("Start initializeLoopInfo");
104f6f88e66Sthomasraoux   forOp = op;
105ef112833SThomas Raoux   ub = forOp.getUpperBound();
106ef112833SThomas Raoux   lb = forOp.getLowerBound();
107ef112833SThomas Raoux   step = forOp.getStep();
108ef112833SThomas Raoux 
109ef112833SThomas Raoux   dynamicLoop = true;
110ef112833SThomas Raoux   auto upperBoundCst = getConstantIntValue(ub);
111ef112833SThomas Raoux   auto lowerBoundCst = getConstantIntValue(lb);
112ef112833SThomas Raoux   auto stepCst = getConstantIntValue(step);
113c7592c77SNicolas Vasilache   if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114ef112833SThomas Raoux     if (!options.supportDynamicLoops) {
115ef112833SThomas Raoux       LDBG("--dynamic loop not supported -> BAIL");
116f6f88e66Sthomasraoux       return false;
117c7592c77SNicolas Vasilache     }
118ef112833SThomas Raoux   } else {
119ef112833SThomas Raoux     int64_t ubImm = upperBoundCst.value();
120ef112833SThomas Raoux     int64_t lbImm = lowerBoundCst.value();
121ef112833SThomas Raoux     int64_t stepImm = stepCst.value();
1220fb216fbSRamkumar Ramachandra     int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
123ef112833SThomas Raoux     if (numIteration > maxStage) {
124ef112833SThomas Raoux       dynamicLoop = false;
125ef112833SThomas Raoux     } else if (!options.supportDynamicLoops) {
126ef112833SThomas Raoux       LDBG("--fewer loop iterations than pipeline stages -> BAIL");
127ef112833SThomas Raoux       return false;
128ef112833SThomas Raoux     }
129ef112833SThomas Raoux   }
130205c08b5SThomas Raoux   peelEpilogue = options.peelEpilogue;
131205c08b5SThomas Raoux   predicateFn = options.predicateFn;
132ef112833SThomas Raoux   if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
133c7592c77SNicolas Vasilache     LDBG("--no epilogue or predicate set -> BAIL");
134205c08b5SThomas Raoux     return false;
135c7592c77SNicolas Vasilache   }
136f6f88e66Sthomasraoux   std::vector<std::pair<Operation *, unsigned>> schedule;
137f6f88e66Sthomasraoux   options.getScheduleFn(forOp, schedule);
138c7592c77SNicolas Vasilache   if (schedule.empty()) {
139c7592c77SNicolas Vasilache     LDBG("--empty schedule -> BAIL");
140f6f88e66Sthomasraoux     return false;
141c7592c77SNicolas Vasilache   }
142f6f88e66Sthomasraoux 
143f6f88e66Sthomasraoux   opOrder.reserve(schedule.size());
144f6f88e66Sthomasraoux   for (auto &opSchedule : schedule) {
145f6f88e66Sthomasraoux     maxStage = std::max(maxStage, opSchedule.second);
146f6f88e66Sthomasraoux     stages[opSchedule.first] = opSchedule.second;
147f6f88e66Sthomasraoux     opOrder.push_back(opSchedule.first);
148f6f88e66Sthomasraoux   }
149f6f88e66Sthomasraoux 
150f6f88e66Sthomasraoux   // All operations need to have a stage.
151f5fe92f6SChristopher Bate   for (Operation &op : forOp.getBody()->without_terminator()) {
152ce14f7b1SKazu Hirata     if (!stages.contains(&op)) {
153f5fe92f6SChristopher Bate       op.emitOpError("not assigned a pipeline stage");
154c7592c77SNicolas Vasilache       LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
155f6f88e66Sthomasraoux       return false;
156f5fe92f6SChristopher Bate     }
157f5fe92f6SChristopher Bate   }
158f5fe92f6SChristopher Bate 
159c933bd81SThomas Raoux   if (!verifySchedule()) {
160c933bd81SThomas Raoux     LDBG("--invalid schedule: " << op << " -> BAIL");
161c933bd81SThomas Raoux     return false;
162c933bd81SThomas Raoux   }
163c933bd81SThomas Raoux 
164f5fe92f6SChristopher Bate   // Currently, we do not support assigning stages to ops in nested regions. The
165f5fe92f6SChristopher Bate   // block of all operations assigned a stage should be the single `scf.for`
166f5fe92f6SChristopher Bate   // body block.
167f5fe92f6SChristopher Bate   for (const auto &[op, stageNum] : stages) {
168f5fe92f6SChristopher Bate     (void)stageNum;
169f5fe92f6SChristopher Bate     if (op == forOp.getBody()->getTerminator()) {
170f5fe92f6SChristopher Bate       op->emitError("terminator should not be assigned a stage");
171c7592c77SNicolas Vasilache       LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
172f5fe92f6SChristopher Bate       return false;
173f5fe92f6SChristopher Bate     }
174f5fe92f6SChristopher Bate     if (op->getBlock() != forOp.getBody()) {
175f5fe92f6SChristopher Bate       op->emitOpError("the owning Block of all operations assigned a stage "
176f5fe92f6SChristopher Bate                       "should be the loop body block");
177c7592c77SNicolas Vasilache       LDBG("--the owning Block of all operations assigned a stage "
178c7592c77SNicolas Vasilache            "should be the loop body block: "
179c7592c77SNicolas Vasilache            << *op << " -> BAIL");
180f5fe92f6SChristopher Bate       return false;
181f5fe92f6SChristopher Bate     }
182f5fe92f6SChristopher Bate   }
183f6f88e66Sthomasraoux 
184e66f97e8SKeren Zhou   // Support only loop-carried dependencies with a distance of one iteration or
185e66f97e8SKeren Zhou   // those defined outside of the loop. This means that any dependency within a
186e66f97e8SKeren Zhou   // loop should either be on the immediately preceding iteration, the current
187e66f97e8SKeren Zhou   // iteration, or on variables whose values are set before entering the loop.
18845cb4140Sthomasraoux   if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
18945cb4140Sthomasraoux                    [this](Value operand) {
19045cb4140Sthomasraoux                      Operation *def = operand.getDefiningOp();
191e66f97e8SKeren Zhou                      return !def ||
192e66f97e8SKeren Zhou                             (!stages.contains(def) && forOp->isAncestor(def));
193c7592c77SNicolas Vasilache                    })) {
194e66f97e8SKeren Zhou     LDBG("--only support loop carried dependency with a distance of 1 or "
195e66f97e8SKeren Zhou          "defined outside of the loop -> BAIL");
196f6f88e66Sthomasraoux     return false;
197c7592c77SNicolas Vasilache   }
1980736bbd7SThomas Raoux   annotateFn = options.annotateFn;
199f6f88e66Sthomasraoux   return true;
200f6f88e66Sthomasraoux }
201f6f88e66Sthomasraoux 
20256954a53Spawelszczerbuk /// Find operands of all the nested operations within `op`.
20356954a53Spawelszczerbuk static SetVector<Value> getNestedOperands(Operation *op) {
20456954a53Spawelszczerbuk   SetVector<Value> operands;
20556954a53Spawelszczerbuk   op->walk([&](Operation *nestedOp) {
20656954a53Spawelszczerbuk     for (Value operand : nestedOp->getOperands()) {
20756954a53Spawelszczerbuk       operands.insert(operand);
20856954a53Spawelszczerbuk     }
20956954a53Spawelszczerbuk   });
21056954a53Spawelszczerbuk   return operands;
21156954a53Spawelszczerbuk }
21256954a53Spawelszczerbuk 
213c933bd81SThomas Raoux /// Compute unrolled cycles of each op (consumer) and verify that each op is
214c933bd81SThomas Raoux /// scheduled after its operands (producers) while adjusting for the distance
215c933bd81SThomas Raoux /// between producer and consumer.
216c933bd81SThomas Raoux bool LoopPipelinerInternal::verifySchedule() {
217c933bd81SThomas Raoux   int64_t numCylesPerIter = opOrder.size();
218c933bd81SThomas Raoux   // Pre-compute the unrolled cycle of each op.
219c933bd81SThomas Raoux   DenseMap<Operation *, int64_t> unrolledCyles;
220c933bd81SThomas Raoux   for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
221c933bd81SThomas Raoux     Operation *def = opOrder[cycle];
222c933bd81SThomas Raoux     auto it = stages.find(def);
223c933bd81SThomas Raoux     assert(it != stages.end());
224c933bd81SThomas Raoux     int64_t stage = it->second;
225c933bd81SThomas Raoux     unrolledCyles[def] = cycle + stage * numCylesPerIter;
226c933bd81SThomas Raoux   }
227c933bd81SThomas Raoux   for (Operation *consumer : opOrder) {
228c933bd81SThomas Raoux     int64_t consumerCycle = unrolledCyles[consumer];
22956954a53Spawelszczerbuk     for (Value operand : getNestedOperands(consumer)) {
230c933bd81SThomas Raoux       auto [producer, distance] = getDefiningOpAndDistance(operand);
231c933bd81SThomas Raoux       if (!producer)
232c933bd81SThomas Raoux         continue;
233c933bd81SThomas Raoux       auto it = unrolledCyles.find(producer);
234c933bd81SThomas Raoux       // Skip producer coming from outside the loop.
235c933bd81SThomas Raoux       if (it == unrolledCyles.end())
236c933bd81SThomas Raoux         continue;
237c933bd81SThomas Raoux       int64_t producerCycle = it->second;
238c933bd81SThomas Raoux       if (consumerCycle < producerCycle - numCylesPerIter * distance) {
239c933bd81SThomas Raoux         consumer->emitError("operation scheduled before its operands");
240c933bd81SThomas Raoux         return false;
241c933bd81SThomas Raoux       }
242c933bd81SThomas Raoux     }
243c933bd81SThomas Raoux   }
244c933bd81SThomas Raoux   return true;
245c933bd81SThomas Raoux }
246c933bd81SThomas Raoux 
247f5fe92f6SChristopher Bate /// Clone `op` and call `callback` on the cloned op's oeprands as well as any
248f5fe92f6SChristopher Bate /// operands of nested ops that:
249f5fe92f6SChristopher Bate /// 1) aren't defined within the new op or
250f5fe92f6SChristopher Bate /// 2) are block arguments.
251f5fe92f6SChristopher Bate static Operation *
252f5fe92f6SChristopher Bate cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
253f5fe92f6SChristopher Bate                        function_ref<void(OpOperand *newOperand)> callback) {
254f5fe92f6SChristopher Bate   Operation *clone = rewriter.clone(*op);
25556954a53Spawelszczerbuk   clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
25656954a53Spawelszczerbuk     // 'clone' itself will be visited first.
257f5fe92f6SChristopher Bate     for (OpOperand &operand : nested->getOpOperands()) {
258f5fe92f6SChristopher Bate       Operation *def = operand.get().getDefiningOp();
2595550c821STres Popp       if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
260f5fe92f6SChristopher Bate         callback(&operand);
261f5fe92f6SChristopher Bate     }
262f5fe92f6SChristopher Bate   });
263f5fe92f6SChristopher Bate   return clone;
264f5fe92f6SChristopher Bate }
265f5fe92f6SChristopher Bate 
/// Emit the prologue before the kernel: `maxStage` peeled parts. Part `i`
/// (0 <= i < maxStage) clones, in `opOrder`, every op whose stage `s` is <= i,
/// i.e. stage `s` of original iteration `i - s`. Every value produced here is
/// recorded in `valueMapping` under the version of the iteration it belongs
/// to. For dynamic loops, each cloned op is wrapped by `predicateFn` using the
/// predicate of its iteration; returns failure if `predicateFn` returns null.
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Initialize the iteration argument to the loop initial values.
  for (auto [arg, operand] :
       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Location loc = forOp.getLoc();
  // predicates[k] is true iff original iteration k exists. Entries stay null
  // for non-dynamic loops, in which case no op is predicated.
  SmallVector<Value> predicates(maxStage);
  for (int64_t i = 0; i < maxStage; i++) {
    if (dynamicLoop) {
      Type t = ub.getType();
      // pred = ub > lb + (i * step)
      Value iv = rewriter.create<arith::AddIOp>(
          loc, lb,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, i))));
      predicates[i] = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, iv, ub);
    }

    // special handling for induction variable as the increment is implicit.
    // iv = lb + i * step
    Type t = lb.getType();
    Value iv = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    for (Operation *op : opOrder) {
      // Ops of a later stage have not started yet in part `i`.
      if (stages[op] > i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            // Use the version of the operand that belongs to the same
            // original iteration (`i - stages[op]`), when one is recorded.
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      // Index of the original iteration this clone belongs to.
      int predicateIdx = i - stages[op];
      if (predicates[predicateIdx]) {
        OpBuilder::InsertionGuard insertGuard(rewriter);
        newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
        if (newOp == nullptr)
          return failure();
      }
      if (annotateFn)
        annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        Value source = newOp->getResult(destId);
        // If the value is a loop carried dependency, it becomes the iter_arg
        // version for the next original iteration.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          if (predicates[predicateIdx] &&
              !forOp.getResult(operand.getOperandNumber()).use_empty()) {
            // If the value is used outside the loop, we need to make sure we
            // return the correct version of it: when the iteration does not
            // exist, keep the previous iter_arg version.
            Value prevValue = valueMapping
                [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
                [i - stages[op]];
            source = rewriter.create<arith::SelectOp>(
                loc, predicates[predicateIdx], source, prevValue);
          }
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          source, i - stages[op] + 1);
        }
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
      }
    }
  }
  return success();
}
344f6f88e66Sthomasraoux }
345f6f88e66Sthomasraoux 
346f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
347f6f88e66Sthomasraoux LoopPipelinerInternal::analyzeCrossStageValues() {
348f6f88e66Sthomasraoux   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
349f6f88e66Sthomasraoux   for (Operation *op : opOrder) {
350f6f88e66Sthomasraoux     unsigned stage = stages[op];
351f5fe92f6SChristopher Bate 
352f5fe92f6SChristopher Bate     auto analyzeOperand = [&](OpOperand &operand) {
35319e068b0SThomas Raoux       auto [def, distance] = getDefiningOpAndDistance(operand.get());
354f6f88e66Sthomasraoux       if (!def)
355f5fe92f6SChristopher Bate         return;
356f6f88e66Sthomasraoux       auto defStage = stages.find(def);
35719e068b0SThomas Raoux       if (defStage == stages.end() || defStage->second == stage ||
35819e068b0SThomas Raoux           defStage->second == stage + distance)
359f5fe92f6SChristopher Bate         return;
360f6f88e66Sthomasraoux       assert(stage > defStage->second);
361f6f88e66Sthomasraoux       LiverangeInfo &info = crossStageValues[operand.get()];
362f6f88e66Sthomasraoux       info.defStage = defStage->second;
363f6f88e66Sthomasraoux       info.lastUseStage = std::max(info.lastUseStage, stage);
364f5fe92f6SChristopher Bate     };
365f5fe92f6SChristopher Bate 
366f5fe92f6SChristopher Bate     for (OpOperand &operand : op->getOpOperands())
367f5fe92f6SChristopher Bate       analyzeOperand(operand);
368f5fe92f6SChristopher Bate     visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
369f5fe92f6SChristopher Bate       analyzeOperand(*operand);
370f5fe92f6SChristopher Bate     });
371f6f88e66Sthomasraoux   }
372f6f88e66Sthomasraoux   return crossStageValues;
373f6f88e66Sthomasraoux }
374f6f88e66Sthomasraoux 
37519e068b0SThomas Raoux std::pair<Operation *, int64_t>
37619e068b0SThomas Raoux LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
37719e068b0SThomas Raoux   int64_t distance = 0;
37819e068b0SThomas Raoux   if (auto arg = dyn_cast<BlockArgument>(value)) {
37919e068b0SThomas Raoux     if (arg.getOwner() != forOp.getBody())
38019e068b0SThomas Raoux       return {nullptr, 0};
38119e068b0SThomas Raoux     // Ignore induction variable.
38219e068b0SThomas Raoux     if (arg.getArgNumber() == 0)
38319e068b0SThomas Raoux       return {nullptr, 0};
38419e068b0SThomas Raoux     distance++;
38519e068b0SThomas Raoux     value =
38619e068b0SThomas Raoux         forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
38719e068b0SThomas Raoux   }
38819e068b0SThomas Raoux   Operation *def = value.getDefiningOp();
38919e068b0SThomas Raoux   if (!def)
39019e068b0SThomas Raoux     return {nullptr, 0};
39119e068b0SThomas Raoux   return {def, distance};
39219e068b0SThomas Raoux }
39319e068b0SThomas Raoux 
394f6f88e66Sthomasraoux scf::ForOp LoopPipelinerInternal::createKernelLoop(
395f6f88e66Sthomasraoux     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
396f6f88e66Sthomasraoux         &crossStageValues,
3971cff4cbdSNicolas Vasilache     RewriterBase &rewriter,
398f6f88e66Sthomasraoux     llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
399f6f88e66Sthomasraoux   // Creates the list of initial values associated to values used across
400f6f88e66Sthomasraoux   // stages. The initial values come from the prologue created above.
401f6f88e66Sthomasraoux   // Keep track of the kernel argument associated to each version of the
402f6f88e66Sthomasraoux   // values passed to the kernel.
40345cb4140Sthomasraoux   llvm::SmallVector<Value> newLoopArg;
40445cb4140Sthomasraoux   // For existing loop argument initialize them with the right version from the
40545cb4140Sthomasraoux   // prologue.
406e4853be2SMehdi Amini   for (const auto &retVal :
40745cb4140Sthomasraoux        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
40845cb4140Sthomasraoux     Operation *def = retVal.value().getDefiningOp();
409e66f97e8SKeren Zhou     assert(def && "Only support loop carried dependencies of distance of 1 or "
410e66f97e8SKeren Zhou                   "outside the loop");
411e66f97e8SKeren Zhou     auto defStage = stages.find(def);
412e66f97e8SKeren Zhou     if (defStage != stages.end()) {
413e66f97e8SKeren Zhou       Value valueVersion =
414e66f97e8SKeren Zhou           valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
415e66f97e8SKeren Zhou                       [maxStage - defStage->second];
41645cb4140Sthomasraoux       assert(valueVersion);
41745cb4140Sthomasraoux       newLoopArg.push_back(valueVersion);
418e66f97e8SKeren Zhou     } else
419e66f97e8SKeren Zhou       newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
42045cb4140Sthomasraoux   }
421f6f88e66Sthomasraoux   for (auto escape : crossStageValues) {
422f6f88e66Sthomasraoux     LiverangeInfo &info = escape.second;
423f6f88e66Sthomasraoux     Value value = escape.first;
424f6f88e66Sthomasraoux     for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
425f6f88e66Sthomasraoux          stageIdx++) {
426f6f88e66Sthomasraoux       Value valueVersion =
427f6f88e66Sthomasraoux           valueMapping[value][maxStage - info.lastUseStage + stageIdx];
428f6f88e66Sthomasraoux       assert(valueVersion);
429f6f88e66Sthomasraoux       newLoopArg.push_back(valueVersion);
430f6f88e66Sthomasraoux       loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
431f6f88e66Sthomasraoux                                            stageIdx)] = newLoopArg.size() - 1;
432f6f88e66Sthomasraoux     }
433f6f88e66Sthomasraoux   }
434f6f88e66Sthomasraoux 
435205c08b5SThomas Raoux   // Create the new kernel loop. When we peel the epilgue we need to peel
436205c08b5SThomas Raoux   // `numStages - 1` iterations. Then we adjust the upper bound to remove those
437205c08b5SThomas Raoux   // iterations.
438205c08b5SThomas Raoux   Value newUb = forOp.getUpperBound();
439ef112833SThomas Raoux   if (peelEpilogue) {
440ef112833SThomas Raoux     Type t = ub.getType();
441ef112833SThomas Raoux     Location loc = forOp.getLoc();
442ef112833SThomas Raoux     // newUb = ub - maxStage * step
443ef112833SThomas Raoux     Value maxStageValue = rewriter.create<arith::ConstantOp>(
444ef112833SThomas Raoux         loc, rewriter.getIntegerAttr(t, maxStage));
445ef112833SThomas Raoux     Value maxStageByStep =
446ef112833SThomas Raoux         rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
447ef112833SThomas Raoux     newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
448ef112833SThomas Raoux   }
449c0342a2dSJacques Pienaar   auto newForOp =
450c0342a2dSJacques Pienaar       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
451c0342a2dSJacques Pienaar                                   forOp.getStep(), newLoopArg);
452f5fe92f6SChristopher Bate   // When there are no iter args, the loop body terminator will be created.
453f5fe92f6SChristopher Bate   // Since we always create it below, remove the terminator if it was created.
454f5fe92f6SChristopher Bate   if (!newForOp.getBody()->empty())
455f5fe92f6SChristopher Bate     rewriter.eraseOp(newForOp.getBody()->getTerminator());
456f6f88e66Sthomasraoux   return newForOp;
457f6f88e66Sthomasraoux }
458f6f88e66Sthomasraoux 
459371366ceSAlex Zinenko LogicalResult LoopPipelinerInternal::createKernel(
460f6f88e66Sthomasraoux     scf::ForOp newForOp,
461f6f88e66Sthomasraoux     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
462f6f88e66Sthomasraoux         &crossStageValues,
463f6f88e66Sthomasraoux     const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
4641cff4cbdSNicolas Vasilache     RewriterBase &rewriter) {
465f6f88e66Sthomasraoux   valueMapping.clear();
466f6f88e66Sthomasraoux 
467f6f88e66Sthomasraoux   // Create the kernel, we clone instruction based on the order given by
468f6f88e66Sthomasraoux   // user and remap operands coming from a previous stages.
469f6f88e66Sthomasraoux   rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
4704d67b278SJeff Niu   IRMapping mapping;
471f6f88e66Sthomasraoux   mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
472e4853be2SMehdi Amini   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
47345cb4140Sthomasraoux     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
47445cb4140Sthomasraoux   }
475205c08b5SThomas Raoux   SmallVector<Value> predicates(maxStage + 1, nullptr);
476205c08b5SThomas Raoux   if (!peelEpilogue) {
477205c08b5SThomas Raoux     // Create a predicate for each stage except the last stage.
478ef112833SThomas Raoux     Location loc = newForOp.getLoc();
479ef112833SThomas Raoux     Type t = ub.getType();
480205c08b5SThomas Raoux     for (unsigned i = 0; i < maxStage; i++) {
481ef112833SThomas Raoux       // c = ub - (maxStage - i) * step
482ef112833SThomas Raoux       Value c = rewriter.create<arith::SubIOp>(
483ef112833SThomas Raoux           loc, ub,
484ef112833SThomas Raoux           rewriter.create<arith::MulIOp>(
485ef112833SThomas Raoux               loc, step,
486ef112833SThomas Raoux               rewriter.create<arith::ConstantOp>(
487ef112833SThomas Raoux                   loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
488ef112833SThomas Raoux 
489205c08b5SThomas Raoux       Value pred = rewriter.create<arith::CmpIOp>(
490205c08b5SThomas Raoux           newForOp.getLoc(), arith::CmpIPredicate::slt,
491205c08b5SThomas Raoux           newForOp.getInductionVar(), c);
492205c08b5SThomas Raoux       predicates[i] = pred;
493205c08b5SThomas Raoux     }
494205c08b5SThomas Raoux   }
495f6f88e66Sthomasraoux   for (Operation *op : opOrder) {
496f6f88e66Sthomasraoux     int64_t useStage = stages[op];
497f6f88e66Sthomasraoux     auto *newOp = rewriter.clone(*op, mapping);
498117db47dSThomas Raoux     SmallVector<OpOperand *> operands;
499117db47dSThomas Raoux     // Collect all the operands for the cloned op and its nested ops.
500117db47dSThomas Raoux     op->walk([&operands](Operation *nestedOp) {
501117db47dSThomas Raoux       for (OpOperand &operand : nestedOp->getOpOperands()) {
502117db47dSThomas Raoux         operands.push_back(&operand);
503117db47dSThomas Raoux       }
504117db47dSThomas Raoux     });
505117db47dSThomas Raoux     for (OpOperand *operand : operands) {
506117db47dSThomas Raoux       Operation *nestedNewOp = mapping.lookup(operand->getOwner());
507117db47dSThomas Raoux       // Special case for the induction variable uses. We replace it with a
508117db47dSThomas Raoux       // version incremented based on the stage where it is used.
509117db47dSThomas Raoux       if (operand->get() == forOp.getInductionVar()) {
510117db47dSThomas Raoux         rewriter.setInsertionPoint(newOp);
511ef112833SThomas Raoux 
512ef112833SThomas Raoux         // offset = (maxStage - stages[op]) * step
513ef112833SThomas Raoux         Type t = step.getType();
514ef112833SThomas Raoux         Value offset = rewriter.create<arith::MulIOp>(
515ef112833SThomas Raoux             forOp.getLoc(), step,
516ef112833SThomas Raoux             rewriter.create<arith::ConstantOp>(
517ef112833SThomas Raoux                 forOp.getLoc(),
518ef112833SThomas Raoux                 rewriter.getIntegerAttr(t, maxStage - stages[op])));
519117db47dSThomas Raoux         Value iv = rewriter.create<arith::AddIOp>(
520117db47dSThomas Raoux             forOp.getLoc(), newForOp.getInductionVar(), offset);
521117db47dSThomas Raoux         nestedNewOp->setOperand(operand->getOperandNumber(), iv);
522117db47dSThomas Raoux         rewriter.setInsertionPointAfter(newOp);
523117db47dSThomas Raoux         continue;
524117db47dSThomas Raoux       }
52519e068b0SThomas Raoux       Value source = operand->get();
52619e068b0SThomas Raoux       auto arg = dyn_cast<BlockArgument>(source);
527117db47dSThomas Raoux       if (arg && arg.getOwner() == forOp.getBody()) {
528117db47dSThomas Raoux         Value ret = forOp.getBody()->getTerminator()->getOperand(
529117db47dSThomas Raoux             arg.getArgNumber() - 1);
530117db47dSThomas Raoux         Operation *dep = ret.getDefiningOp();
531117db47dSThomas Raoux         if (!dep)
532117db47dSThomas Raoux           continue;
533117db47dSThomas Raoux         auto stageDep = stages.find(dep);
534117db47dSThomas Raoux         if (stageDep == stages.end() || stageDep->second == useStage)
535117db47dSThomas Raoux           continue;
53619e068b0SThomas Raoux         // If the value is a loop carried value coming from stage N + 1 remap,
53719e068b0SThomas Raoux         // it will become a direct use.
53819e068b0SThomas Raoux         if (stageDep->second == useStage + 1) {
539117db47dSThomas Raoux           nestedNewOp->setOperand(operand->getOperandNumber(),
540117db47dSThomas Raoux                                   mapping.lookupOrDefault(ret));
541117db47dSThomas Raoux           continue;
542117db47dSThomas Raoux         }
54319e068b0SThomas Raoux         source = ret;
54419e068b0SThomas Raoux       }
545117db47dSThomas Raoux       // For operands defined in a previous stage we need to remap it to use
546117db47dSThomas Raoux       // the correct region argument. We look for the right version of the
547117db47dSThomas Raoux       // Value based on the stage where it is used.
54819e068b0SThomas Raoux       Operation *def = source.getDefiningOp();
549117db47dSThomas Raoux       if (!def)
550117db47dSThomas Raoux         continue;
551117db47dSThomas Raoux       auto stageDef = stages.find(def);
552117db47dSThomas Raoux       if (stageDef == stages.end() || stageDef->second == useStage)
553117db47dSThomas Raoux         continue;
554117db47dSThomas Raoux       auto remap = loopArgMap.find(
555117db47dSThomas Raoux           std::make_pair(operand->get(), useStage - stageDef->second));
556117db47dSThomas Raoux       assert(remap != loopArgMap.end());
557117db47dSThomas Raoux       nestedNewOp->setOperand(operand->getOperandNumber(),
558117db47dSThomas Raoux                               newForOp.getRegionIterArgs()[remap->second]);
559117db47dSThomas Raoux     }
560f5fe92f6SChristopher Bate 
561205c08b5SThomas Raoux     if (predicates[useStage]) {
562ebf05993SSJW       OpBuilder::InsertionGuard insertGuard(rewriter);
5631cff4cbdSNicolas Vasilache       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
564371366ceSAlex Zinenko       if (!newOp)
565371366ceSAlex Zinenko         return failure();
566205c08b5SThomas Raoux       // Remap the results to the new predicated one.
567205c08b5SThomas Raoux       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
568205c08b5SThomas Raoux         mapping.map(std::get<0>(values), std::get<1>(values));
569205c08b5SThomas Raoux     }
5700736bbd7SThomas Raoux     if (annotateFn)
5710736bbd7SThomas Raoux       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
572f6f88e66Sthomasraoux   }
573f6f88e66Sthomasraoux 
574f6f88e66Sthomasraoux   // Collect the Values that need to be returned by the forOp. For each
575f6f88e66Sthomasraoux   // value we need to have `LastUseStage - DefStage` number of versions
576f6f88e66Sthomasraoux   // returned.
577f6f88e66Sthomasraoux   // We create a mapping between original values and the associated loop
578f6f88e66Sthomasraoux   // returned values that will be needed by the epilogue.
579f6f88e66Sthomasraoux   llvm::SmallVector<Value> yieldOperands;
58019e068b0SThomas Raoux   for (OpOperand &yieldOperand :
58119e068b0SThomas Raoux        forOp.getBody()->getTerminator()->getOpOperands()) {
58219e068b0SThomas Raoux     Value source = mapping.lookupOrDefault(yieldOperand.get());
58319e068b0SThomas Raoux     // When we don't peel the epilogue and the yield value is used outside the
58419e068b0SThomas Raoux     // loop we need to make sure we return the version from numStages -
58519e068b0SThomas Raoux     // defStage.
58619e068b0SThomas Raoux     if (!peelEpilogue &&
58719e068b0SThomas Raoux         !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
58819e068b0SThomas Raoux       Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
58919e068b0SThomas Raoux       if (def) {
59019e068b0SThomas Raoux         auto defStage = stages.find(def);
59119e068b0SThomas Raoux         if (defStage != stages.end() && defStage->second < maxStage) {
59219e068b0SThomas Raoux           Value pred = predicates[defStage->second];
59319e068b0SThomas Raoux           source = rewriter.create<arith::SelectOp>(
59419e068b0SThomas Raoux               pred.getLoc(), pred, source,
59519e068b0SThomas Raoux               newForOp.getBody()
59619e068b0SThomas Raoux                   ->getArguments()[yieldOperand.getOperandNumber() + 1]);
59745cb4140Sthomasraoux         }
59819e068b0SThomas Raoux       }
59919e068b0SThomas Raoux     }
60019e068b0SThomas Raoux     yieldOperands.push_back(source);
60119e068b0SThomas Raoux   }
60219e068b0SThomas Raoux 
603f6f88e66Sthomasraoux   for (auto &it : crossStageValues) {
604f6f88e66Sthomasraoux     int64_t version = maxStage - it.second.lastUseStage + 1;
605f6f88e66Sthomasraoux     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
606371366ceSAlex Zinenko     // add the original version to yield ops.
607371366ceSAlex Zinenko     // If there is a live range spanning across more than 2 stages we need to
608371366ceSAlex Zinenko     // add extra arg.
609f6f88e66Sthomasraoux     for (unsigned i = 1; i < numVersionReturned; i++) {
610f6f88e66Sthomasraoux       setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
611f6f88e66Sthomasraoux                       version++);
612f6f88e66Sthomasraoux       yieldOperands.push_back(
613f6f88e66Sthomasraoux           newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
614f6f88e66Sthomasraoux                                              newForOp.getNumInductionVars()]);
615f6f88e66Sthomasraoux     }
616f6f88e66Sthomasraoux     setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
617f6f88e66Sthomasraoux                     version++);
618f6f88e66Sthomasraoux     yieldOperands.push_back(mapping.lookupOrDefault(it.first));
619f6f88e66Sthomasraoux   }
62045cb4140Sthomasraoux   // Map the yield operand to the forOp returned value.
621e4853be2SMehdi Amini   for (const auto &retVal :
62245cb4140Sthomasraoux        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
62345cb4140Sthomasraoux     Operation *def = retVal.value().getDefiningOp();
624e66f97e8SKeren Zhou     assert(def && "Only support loop carried dependencies of distance of 1 or "
625e66f97e8SKeren Zhou                   "defined outside the loop");
626e66f97e8SKeren Zhou     auto defStage = stages.find(def);
627e66f97e8SKeren Zhou     if (defStage == stages.end()) {
628e66f97e8SKeren Zhou       for (unsigned int stage = 1; stage <= maxStage; stage++)
629e66f97e8SKeren Zhou         setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
630e66f97e8SKeren Zhou                         retVal.value(), stage);
631e66f97e8SKeren Zhou     } else if (defStage->second > 0) {
63245cb4140Sthomasraoux       setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
63345cb4140Sthomasraoux                       newForOp->getResult(retVal.index()),
634e66f97e8SKeren Zhou                       maxStage - defStage->second + 1);
63545cb4140Sthomasraoux     }
63619e068b0SThomas Raoux   }
637f6f88e66Sthomasraoux   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
638371366ceSAlex Zinenko   return success();
639f6f88e66Sthomasraoux }
640f6f88e66Sthomasraoux 
641ebf05993SSJW LogicalResult
642ebf05993SSJW LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
643ebf05993SSJW                                     llvm::SmallVector<Value> &returnValues) {
644ebf05993SSJW   Location loc = forOp.getLoc();
645*8da5aa16SSJW   Type t = lb.getType();
646*8da5aa16SSJW 
647f6f88e66Sthomasraoux   // Emit different versions of the induction variable. They will be
648f6f88e66Sthomasraoux   // removed by dead code if not used.
649ebf05993SSJW 
650*8da5aa16SSJW   auto createConst = [&](int v) {
651*8da5aa16SSJW     return rewriter.create<arith::ConstantOp>(loc,
652*8da5aa16SSJW                                               rewriter.getIntegerAttr(t, v));
653*8da5aa16SSJW   };
654*8da5aa16SSJW 
655*8da5aa16SSJW   // total_iterations = cdiv(range_diff, step);
656*8da5aa16SSJW   // - range_diff = ub - lb
657*8da5aa16SSJW   // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
658*8da5aa16SSJW   Value zero = createConst(0);
659*8da5aa16SSJW   Value one = createConst(1);
6607645d9c7SSJW   Value stepLessZero = rewriter.create<arith::CmpIOp>(
6617645d9c7SSJW       loc, arith::CmpIPredicate::slt, step, zero);
6627645d9c7SSJW   Value stepDecr =
663*8da5aa16SSJW       rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
6647645d9c7SSJW 
6657645d9c7SSJW   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
6667645d9c7SSJW   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
6677645d9c7SSJW   Value rangeDecr =
6687645d9c7SSJW       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
6697645d9c7SSJW   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
670fa089b01SSJW 
671*8da5aa16SSJW   // If total_iters < max_stage, start the epilogue at zero to match the
672*8da5aa16SSJW   // ramp-up in the prologue.
673*8da5aa16SSJW   // start_iter = max(0, total_iters - max_stage)
674*8da5aa16SSJW   Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
675*8da5aa16SSJW                                                createConst(maxStage));
676*8da5aa16SSJW   iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
677*8da5aa16SSJW 
678*8da5aa16SSJW   // Capture predicates for dynamic loops.
679ebf05993SSJW   SmallVector<Value> predicates(maxStage + 1);
680*8da5aa16SSJW 
681*8da5aa16SSJW   for (int64_t i = 1; i <= maxStage; i++) {
682ebf05993SSJW     // newLastIter = lb + step * iterI
683ef112833SThomas Raoux     Value newlastIter = rewriter.create<arith::AddIOp>(
684ebf05993SSJW         loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
685ebf05993SSJW 
686*8da5aa16SSJW     setValueMapping(forOp.getInductionVar(), newlastIter, i);
687*8da5aa16SSJW 
688*8da5aa16SSJW     // increment to next iterI
689*8da5aa16SSJW     iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
690ebf05993SSJW 
691ebf05993SSJW     if (dynamicLoop) {
692*8da5aa16SSJW       // Disable stages when `i` is greater than total_iters.
693*8da5aa16SSJW       // pred = total_iters >= i
694*8da5aa16SSJW       predicates[i] = rewriter.create<arith::CmpIOp>(
695*8da5aa16SSJW           loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
696f6f88e66Sthomasraoux     }
697ebf05993SSJW   }
698ebf05993SSJW 
699f5fe92f6SChristopher Bate   // Emit `maxStage - 1` epilogue part that includes operations from stages
700f6f88e66Sthomasraoux   // [i; maxStage].
701f6f88e66Sthomasraoux   for (int64_t i = 1; i <= maxStage; i++) {
702ebf05993SSJW     SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
703f6f88e66Sthomasraoux     for (Operation *op : opOrder) {
704f6f88e66Sthomasraoux       if (stages[op] < i)
705f6f88e66Sthomasraoux         continue;
706ebf05993SSJW       unsigned currentVersion = maxStage - stages[op] + i;
707ebf05993SSJW       unsigned nextVersion = currentVersion + 1;
708f5fe92f6SChristopher Bate       Operation *newOp =
709f5fe92f6SChristopher Bate           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
710f5fe92f6SChristopher Bate             auto it = valueMapping.find(newOperand->get());
711f6f88e66Sthomasraoux             if (it != valueMapping.end()) {
712ebf05993SSJW               Value replacement = it->second[currentVersion];
713f5fe92f6SChristopher Bate               newOperand->set(replacement);
714f6f88e66Sthomasraoux             }
715f5fe92f6SChristopher Bate           });
716ebf05993SSJW       if (dynamicLoop) {
717ebf05993SSJW         OpBuilder::InsertionGuard insertGuard(rewriter);
718ebf05993SSJW         newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
719ebf05993SSJW         if (!newOp)
720ebf05993SSJW           return failure();
721ebf05993SSJW       }
7220736bbd7SThomas Raoux       if (annotateFn)
7230736bbd7SThomas Raoux         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
724ebf05993SSJW 
725ebf05993SSJW       for (auto [opRes, newRes] :
726ebf05993SSJW            llvm::zip(op->getResults(), newOp->getResults())) {
727ebf05993SSJW         setValueMapping(opRes, newRes, currentVersion);
72845cb4140Sthomasraoux         // If the value is a loop carried dependency update the loop argument
72945cb4140Sthomasraoux         // mapping and keep track of the last version to replace the original
73045cb4140Sthomasraoux         // forOp uses.
73145cb4140Sthomasraoux         for (OpOperand &operand :
73245cb4140Sthomasraoux              forOp.getBody()->getTerminator()->getOpOperands()) {
733ebf05993SSJW           if (operand.get() != opRes)
73445cb4140Sthomasraoux             continue;
73545cb4140Sthomasraoux           // If the version is greater than maxStage it means it maps to the
73645cb4140Sthomasraoux           // original forOp returned value.
737ebf05993SSJW           unsigned ri = operand.getOperandNumber();
738ebf05993SSJW           returnValues[ri] = newRes;
739ebf05993SSJW           Value mapVal = forOp.getRegionIterArgs()[ri];
740ebf05993SSJW           returnMap[ri] = std::make_pair(mapVal, currentVersion);
741ebf05993SSJW           if (nextVersion <= maxStage)
742ebf05993SSJW             setValueMapping(mapVal, newRes, nextVersion);
74345cb4140Sthomasraoux         }
744ebf05993SSJW       }
745ebf05993SSJW     }
746ebf05993SSJW     if (dynamicLoop) {
747ebf05993SSJW       // Select return values from this stage (live outs) based on predication.
748ebf05993SSJW       // If the stage is valid select the peeled value, else use previous stage
749ebf05993SSJW       // value.
750ebf05993SSJW       for (auto pair : llvm::enumerate(returnValues)) {
751ebf05993SSJW         unsigned ri = pair.index();
752ebf05993SSJW         auto [mapVal, currentVersion] = returnMap[ri];
753ebf05993SSJW         if (mapVal) {
754ebf05993SSJW           unsigned nextVersion = currentVersion + 1;
755ebf05993SSJW           Value pred = predicates[currentVersion];
756ebf05993SSJW           Value prevValue = valueMapping[mapVal][currentVersion];
757ebf05993SSJW           auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
758ebf05993SSJW                                                         prevValue);
759ebf05993SSJW           returnValues[ri] = selOp;
760ebf05993SSJW           if (nextVersion <= maxStage)
761ebf05993SSJW             setValueMapping(mapVal, selOp, nextVersion);
762f6f88e66Sthomasraoux         }
763f6f88e66Sthomasraoux       }
764f6f88e66Sthomasraoux     }
765f6f88e66Sthomasraoux   }
766ebf05993SSJW   return success();
76745cb4140Sthomasraoux }
768f6f88e66Sthomasraoux 
769f6f88e66Sthomasraoux void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
770f6f88e66Sthomasraoux   auto it = valueMapping.find(key);
771f6f88e66Sthomasraoux   // If the value is not in the map yet add a vector big enough to store all
772f6f88e66Sthomasraoux   // versions.
773f6f88e66Sthomasraoux   if (it == valueMapping.end())
774f6f88e66Sthomasraoux     it =
775f6f88e66Sthomasraoux         valueMapping
776f6f88e66Sthomasraoux             .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
777f6f88e66Sthomasraoux             .first;
778f6f88e66Sthomasraoux   it->second[idx] = el;
779f6f88e66Sthomasraoux }
780f6f88e66Sthomasraoux 
7815f0d4f20SAlex Zinenko } // namespace
7825f0d4f20SAlex Zinenko 
7831cff4cbdSNicolas Vasilache FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
784371366ceSAlex Zinenko                                             const PipeliningOption &options,
785371366ceSAlex Zinenko                                             bool *modifiedIR) {
786371366ceSAlex Zinenko   if (modifiedIR)
787371366ceSAlex Zinenko     *modifiedIR = false;
788f6f88e66Sthomasraoux   LoopPipelinerInternal pipeliner;
789f6f88e66Sthomasraoux   if (!pipeliner.initializeLoopInfo(forOp, options))
790f6f88e66Sthomasraoux     return failure();
791f6f88e66Sthomasraoux 
792371366ceSAlex Zinenko   if (modifiedIR)
793371366ceSAlex Zinenko     *modifiedIR = true;
794371366ceSAlex Zinenko 
795f6f88e66Sthomasraoux   // 1. Emit prologue.
79618926666SSJW   if (failed(pipeliner.emitPrologue(rewriter)))
79718926666SSJW     return failure();
798f6f88e66Sthomasraoux 
799f6f88e66Sthomasraoux   // 2. Track values used across stages. When a value cross stages it will
800f6f88e66Sthomasraoux   // need to be passed as loop iteration arguments.
801f6f88e66Sthomasraoux   // We first collect the values that are used in a different stage than where
802f6f88e66Sthomasraoux   // they are defined.
803f6f88e66Sthomasraoux   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
804f6f88e66Sthomasraoux       crossStageValues = pipeliner.analyzeCrossStageValues();
805f6f88e66Sthomasraoux 
806f6f88e66Sthomasraoux   // Mapping between original loop values used cross stage and the block
807f6f88e66Sthomasraoux   // arguments associated after pipelining. A Value may map to several
808f6f88e66Sthomasraoux   // arguments if its liverange spans across more than 2 stages.
809f6f88e66Sthomasraoux   llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
810f6f88e66Sthomasraoux   // 3. Create the new kernel loop and return the block arguments mapping.
811f6f88e66Sthomasraoux   ForOp newForOp =
812f6f88e66Sthomasraoux       pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
813f6f88e66Sthomasraoux   // Create the kernel block, order ops based on user choice and remap
814f6f88e66Sthomasraoux   // operands.
815371366ceSAlex Zinenko   if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
816371366ceSAlex Zinenko                                     rewriter)))
817371366ceSAlex Zinenko     return failure();
818f6f88e66Sthomasraoux 
819205c08b5SThomas Raoux   llvm::SmallVector<Value> returnValues =
820205c08b5SThomas Raoux       newForOp.getResults().take_front(forOp->getNumResults());
821205c08b5SThomas Raoux   if (options.peelEpilogue) {
822f6f88e66Sthomasraoux     // 4. Emit the epilogue after the new forOp.
823f6f88e66Sthomasraoux     rewriter.setInsertionPointAfter(newForOp);
824ebf05993SSJW     if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
825ebf05993SSJW       return failure();
826205c08b5SThomas Raoux   }
827f6f88e66Sthomasraoux   // 5. Erase the original loop and replace the uses with the epilogue output.
828f6f88e66Sthomasraoux   if (forOp->getNumResults() > 0)
82945cb4140Sthomasraoux     rewriter.replaceOp(forOp, returnValues);
830f6f88e66Sthomasraoux   else
831f6f88e66Sthomasraoux     rewriter.eraseOp(forOp);
832f6f88e66Sthomasraoux 
8335f0d4f20SAlex Zinenko   return newForOp;
834f6f88e66Sthomasraoux }
835f6f88e66Sthomasraoux 
836f6f88e66Sthomasraoux void mlir::scf::populateSCFLoopPipeliningPatterns(
837f6f88e66Sthomasraoux     RewritePatternSet &patterns, const PipeliningOption &options) {
8385f0d4f20SAlex Zinenko   patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext());
839f6f88e66Sthomasraoux }
840