xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (revision 8da5aa16f65bc297663573bacd3030f975b9fcde)
1 //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop software pipelining
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
16 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17 #include "mlir/Dialect/SCF/Utils/Utils.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/MathExtras.h"
24 
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 namespace {
33 
34 /// Helper to keep internal information during pipelining transformation.
35 struct LoopPipelinerInternal {
36   /// Coarse liverange information for ops used across stages.
37   struct LiverangeInfo {
38     unsigned lastUseStage = 0;
39     unsigned defStage = 0;
40   };
41 
42 protected:
43   ForOp forOp;
44   unsigned maxStage = 0;
45   DenseMap<Operation *, unsigned> stages;
46   std::vector<Operation *> opOrder;
47   Value ub;
48   Value lb;
49   Value step;
50   bool dynamicLoop;
51   PipeliningOption::AnnotationlFnType annotateFn = nullptr;
52   bool peelEpilogue;
53   PipeliningOption::PredicateOpFn predicateFn = nullptr;
54 
55   // When peeling the kernel we generate several version of each value for
56   // different stage of the prologue. This map tracks the mapping between
57   // original Values in the loop and the different versions
58   // peeled from the loop.
59   DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
60 
61   /// Assign a value to `valueMapping`, this means `val` represents the version
62   /// `idx` of `key` in the epilogue.
63   void setValueMapping(Value key, Value el, int64_t idx);
64 
65   /// Return the defining op of the given value, if the Value is an argument of
66   /// the loop return the associated defining op in the loop and its distance to
67   /// the Value.
68   std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
69 
70   /// Return true if the schedule is possible and return false otherwise. A
71   /// schedule is correct if all definitions are scheduled before uses.
72   bool verifySchedule();
73 
74 public:
75   /// Initalize the information for the given `op`, return true if it
76   /// satisfies the pre-condition to apply pipelining.
77   bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
78   /// Emits the prologue, this creates `maxStage - 1` part which will contain
79   /// operations from stages [0; i], where i is the part index.
80   LogicalResult emitPrologue(RewriterBase &rewriter);
81   /// Gather liverange information for Values that are used in a different stage
82   /// than its definition.
83   llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84   scf::ForOp createKernelLoop(
85       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
86       RewriterBase &rewriter,
87       llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
88   /// Emits the pipelined kernel. This clones loop operations following user
89   /// order and remaps operands defined in a different stage as their use.
90   LogicalResult createKernel(
91       scf::ForOp newForOp,
92       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93       const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
94       RewriterBase &rewriter);
95   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
96   /// operations from stages [i; maxStage], where i is the part index.
97   LogicalResult emitEpilogue(RewriterBase &rewriter,
98                              llvm::SmallVector<Value> &returnValues);
99 };
100 
101 bool LoopPipelinerInternal::initializeLoopInfo(
102     ForOp op, const PipeliningOption &options) {
103   LDBG("Start initializeLoopInfo");
104   forOp = op;
105   ub = forOp.getUpperBound();
106   lb = forOp.getLowerBound();
107   step = forOp.getStep();
108 
109   dynamicLoop = true;
110   auto upperBoundCst = getConstantIntValue(ub);
111   auto lowerBoundCst = getConstantIntValue(lb);
112   auto stepCst = getConstantIntValue(step);
113   if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114     if (!options.supportDynamicLoops) {
115       LDBG("--dynamic loop not supported -> BAIL");
116       return false;
117     }
118   } else {
119     int64_t ubImm = upperBoundCst.value();
120     int64_t lbImm = lowerBoundCst.value();
121     int64_t stepImm = stepCst.value();
122     int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
123     if (numIteration > maxStage) {
124       dynamicLoop = false;
125     } else if (!options.supportDynamicLoops) {
126       LDBG("--fewer loop iterations than pipeline stages -> BAIL");
127       return false;
128     }
129   }
130   peelEpilogue = options.peelEpilogue;
131   predicateFn = options.predicateFn;
132   if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
133     LDBG("--no epilogue or predicate set -> BAIL");
134     return false;
135   }
136   std::vector<std::pair<Operation *, unsigned>> schedule;
137   options.getScheduleFn(forOp, schedule);
138   if (schedule.empty()) {
139     LDBG("--empty schedule -> BAIL");
140     return false;
141   }
142 
143   opOrder.reserve(schedule.size());
144   for (auto &opSchedule : schedule) {
145     maxStage = std::max(maxStage, opSchedule.second);
146     stages[opSchedule.first] = opSchedule.second;
147     opOrder.push_back(opSchedule.first);
148   }
149 
150   // All operations need to have a stage.
151   for (Operation &op : forOp.getBody()->without_terminator()) {
152     if (!stages.contains(&op)) {
153       op.emitOpError("not assigned a pipeline stage");
154       LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
155       return false;
156     }
157   }
158 
159   if (!verifySchedule()) {
160     LDBG("--invalid schedule: " << op << " -> BAIL");
161     return false;
162   }
163 
164   // Currently, we do not support assigning stages to ops in nested regions. The
165   // block of all operations assigned a stage should be the single `scf.for`
166   // body block.
167   for (const auto &[op, stageNum] : stages) {
168     (void)stageNum;
169     if (op == forOp.getBody()->getTerminator()) {
170       op->emitError("terminator should not be assigned a stage");
171       LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
172       return false;
173     }
174     if (op->getBlock() != forOp.getBody()) {
175       op->emitOpError("the owning Block of all operations assigned a stage "
176                       "should be the loop body block");
177       LDBG("--the owning Block of all operations assigned a stage "
178            "should be the loop body block: "
179            << *op << " -> BAIL");
180       return false;
181     }
182   }
183 
184   // Support only loop-carried dependencies with a distance of one iteration or
185   // those defined outside of the loop. This means that any dependency within a
186   // loop should either be on the immediately preceding iteration, the current
187   // iteration, or on variables whose values are set before entering the loop.
188   if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
189                    [this](Value operand) {
190                      Operation *def = operand.getDefiningOp();
191                      return !def ||
192                             (!stages.contains(def) && forOp->isAncestor(def));
193                    })) {
194     LDBG("--only support loop carried dependency with a distance of 1 or "
195          "defined outside of the loop -> BAIL");
196     return false;
197   }
198   annotateFn = options.annotateFn;
199   return true;
200 }
201 
202 /// Find operands of all the nested operations within `op`.
203 static SetVector<Value> getNestedOperands(Operation *op) {
204   SetVector<Value> operands;
205   op->walk([&](Operation *nestedOp) {
206     for (Value operand : nestedOp->getOperands()) {
207       operands.insert(operand);
208     }
209   });
210   return operands;
211 }
212 
213 /// Compute unrolled cycles of each op (consumer) and verify that each op is
214 /// scheduled after its operands (producers) while adjusting for the distance
215 /// between producer and consumer.
216 bool LoopPipelinerInternal::verifySchedule() {
217   int64_t numCylesPerIter = opOrder.size();
218   // Pre-compute the unrolled cycle of each op.
219   DenseMap<Operation *, int64_t> unrolledCyles;
220   for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
221     Operation *def = opOrder[cycle];
222     auto it = stages.find(def);
223     assert(it != stages.end());
224     int64_t stage = it->second;
225     unrolledCyles[def] = cycle + stage * numCylesPerIter;
226   }
227   for (Operation *consumer : opOrder) {
228     int64_t consumerCycle = unrolledCyles[consumer];
229     for (Value operand : getNestedOperands(consumer)) {
230       auto [producer, distance] = getDefiningOpAndDistance(operand);
231       if (!producer)
232         continue;
233       auto it = unrolledCyles.find(producer);
234       // Skip producer coming from outside the loop.
235       if (it == unrolledCyles.end())
236         continue;
237       int64_t producerCycle = it->second;
238       if (consumerCycle < producerCycle - numCylesPerIter * distance) {
239         consumer->emitError("operation scheduled before its operands");
240         return false;
241       }
242     }
243   }
244   return true;
245 }
246 
247 /// Clone `op` and call `callback` on the cloned op's oeprands as well as any
248 /// operands of nested ops that:
249 /// 1) aren't defined within the new op or
250 /// 2) are block arguments.
251 static Operation *
252 cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
253                        function_ref<void(OpOperand *newOperand)> callback) {
254   Operation *clone = rewriter.clone(*op);
255   clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
256     // 'clone' itself will be visited first.
257     for (OpOperand &operand : nested->getOpOperands()) {
258       Operation *def = operand.get().getDefiningOp();
259       if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
260         callback(&operand);
261     }
262   });
263   return clone;
264 }
265 
266 LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
267   // Initialize the iteration argument to the loop initial values.
268   for (auto [arg, operand] :
269        llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
270     setValueMapping(arg, operand.get(), 0);
271   }
272   auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
273   Location loc = forOp.getLoc();
274   SmallVector<Value> predicates(maxStage);
275   for (int64_t i = 0; i < maxStage; i++) {
276     if (dynamicLoop) {
277       Type t = ub.getType();
278       // pred = ub > lb + (i * step)
279       Value iv = rewriter.create<arith::AddIOp>(
280           loc, lb,
281           rewriter.create<arith::MulIOp>(
282               loc, step,
283               rewriter.create<arith::ConstantOp>(
284                   loc, rewriter.getIntegerAttr(t, i))));
285       predicates[i] = rewriter.create<arith::CmpIOp>(
286           loc, arith::CmpIPredicate::slt, iv, ub);
287     }
288 
289     // special handling for induction variable as the increment is implicit.
290     // iv = lb + i * step
291     Type t = lb.getType();
292     Value iv = rewriter.create<arith::AddIOp>(
293         loc, lb,
294         rewriter.create<arith::MulIOp>(
295             loc, step,
296             rewriter.create<arith::ConstantOp>(loc,
297                                                rewriter.getIntegerAttr(t, i))));
298     setValueMapping(forOp.getInductionVar(), iv, i);
299     for (Operation *op : opOrder) {
300       if (stages[op] > i)
301         continue;
302       Operation *newOp =
303           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
304             auto it = valueMapping.find(newOperand->get());
305             if (it != valueMapping.end()) {
306               Value replacement = it->second[i - stages[op]];
307               newOperand->set(replacement);
308             }
309           });
310       int predicateIdx = i - stages[op];
311       if (predicates[predicateIdx]) {
312         OpBuilder::InsertionGuard insertGuard(rewriter);
313         newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
314         if (newOp == nullptr)
315           return failure();
316       }
317       if (annotateFn)
318         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
319       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
320         Value source = newOp->getResult(destId);
321         // If the value is a loop carried dependency update the loop argument
322         for (OpOperand &operand : yield->getOpOperands()) {
323           if (operand.get() != op->getResult(destId))
324             continue;
325           if (predicates[predicateIdx] &&
326               !forOp.getResult(operand.getOperandNumber()).use_empty()) {
327             // If the value is used outside the loop, we need to make sure we
328             // return the correct version of it.
329             Value prevValue = valueMapping
330                 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
331                 [i - stages[op]];
332             source = rewriter.create<arith::SelectOp>(
333                 loc, predicates[predicateIdx], source, prevValue);
334           }
335           setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
336                           source, i - stages[op] + 1);
337         }
338         setValueMapping(op->getResult(destId), newOp->getResult(destId),
339                         i - stages[op]);
340       }
341     }
342   }
343   return success();
344 }
345 
346 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
347 LoopPipelinerInternal::analyzeCrossStageValues() {
348   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
349   for (Operation *op : opOrder) {
350     unsigned stage = stages[op];
351 
352     auto analyzeOperand = [&](OpOperand &operand) {
353       auto [def, distance] = getDefiningOpAndDistance(operand.get());
354       if (!def)
355         return;
356       auto defStage = stages.find(def);
357       if (defStage == stages.end() || defStage->second == stage ||
358           defStage->second == stage + distance)
359         return;
360       assert(stage > defStage->second);
361       LiverangeInfo &info = crossStageValues[operand.get()];
362       info.defStage = defStage->second;
363       info.lastUseStage = std::max(info.lastUseStage, stage);
364     };
365 
366     for (OpOperand &operand : op->getOpOperands())
367       analyzeOperand(operand);
368     visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
369       analyzeOperand(*operand);
370     });
371   }
372   return crossStageValues;
373 }
374 
375 std::pair<Operation *, int64_t>
376 LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
377   int64_t distance = 0;
378   if (auto arg = dyn_cast<BlockArgument>(value)) {
379     if (arg.getOwner() != forOp.getBody())
380       return {nullptr, 0};
381     // Ignore induction variable.
382     if (arg.getArgNumber() == 0)
383       return {nullptr, 0};
384     distance++;
385     value =
386         forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
387   }
388   Operation *def = value.getDefiningOp();
389   if (!def)
390     return {nullptr, 0};
391   return {def, distance};
392 }
393 
394 scf::ForOp LoopPipelinerInternal::createKernelLoop(
395     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
396         &crossStageValues,
397     RewriterBase &rewriter,
398     llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
399   // Creates the list of initial values associated to values used across
400   // stages. The initial values come from the prologue created above.
401   // Keep track of the kernel argument associated to each version of the
402   // values passed to the kernel.
403   llvm::SmallVector<Value> newLoopArg;
404   // For existing loop argument initialize them with the right version from the
405   // prologue.
406   for (const auto &retVal :
407        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
408     Operation *def = retVal.value().getDefiningOp();
409     assert(def && "Only support loop carried dependencies of distance of 1 or "
410                   "outside the loop");
411     auto defStage = stages.find(def);
412     if (defStage != stages.end()) {
413       Value valueVersion =
414           valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
415                       [maxStage - defStage->second];
416       assert(valueVersion);
417       newLoopArg.push_back(valueVersion);
418     } else
419       newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
420   }
421   for (auto escape : crossStageValues) {
422     LiverangeInfo &info = escape.second;
423     Value value = escape.first;
424     for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
425          stageIdx++) {
426       Value valueVersion =
427           valueMapping[value][maxStage - info.lastUseStage + stageIdx];
428       assert(valueVersion);
429       newLoopArg.push_back(valueVersion);
430       loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
431                                            stageIdx)] = newLoopArg.size() - 1;
432     }
433   }
434 
435   // Create the new kernel loop. When we peel the epilgue we need to peel
436   // `numStages - 1` iterations. Then we adjust the upper bound to remove those
437   // iterations.
438   Value newUb = forOp.getUpperBound();
439   if (peelEpilogue) {
440     Type t = ub.getType();
441     Location loc = forOp.getLoc();
442     // newUb = ub - maxStage * step
443     Value maxStageValue = rewriter.create<arith::ConstantOp>(
444         loc, rewriter.getIntegerAttr(t, maxStage));
445     Value maxStageByStep =
446         rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
447     newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
448   }
449   auto newForOp =
450       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
451                                   forOp.getStep(), newLoopArg);
452   // When there are no iter args, the loop body terminator will be created.
453   // Since we always create it below, remove the terminator if it was created.
454   if (!newForOp.getBody()->empty())
455     rewriter.eraseOp(newForOp.getBody()->getTerminator());
456   return newForOp;
457 }
458 
459 LogicalResult LoopPipelinerInternal::createKernel(
460     scf::ForOp newForOp,
461     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
462         &crossStageValues,
463     const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
464     RewriterBase &rewriter) {
465   valueMapping.clear();
466 
467   // Create the kernel, we clone instruction based on the order given by
468   // user and remap operands coming from a previous stages.
469   rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
470   IRMapping mapping;
471   mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
472   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
473     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
474   }
475   SmallVector<Value> predicates(maxStage + 1, nullptr);
476   if (!peelEpilogue) {
477     // Create a predicate for each stage except the last stage.
478     Location loc = newForOp.getLoc();
479     Type t = ub.getType();
480     for (unsigned i = 0; i < maxStage; i++) {
481       // c = ub - (maxStage - i) * step
482       Value c = rewriter.create<arith::SubIOp>(
483           loc, ub,
484           rewriter.create<arith::MulIOp>(
485               loc, step,
486               rewriter.create<arith::ConstantOp>(
487                   loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
488 
489       Value pred = rewriter.create<arith::CmpIOp>(
490           newForOp.getLoc(), arith::CmpIPredicate::slt,
491           newForOp.getInductionVar(), c);
492       predicates[i] = pred;
493     }
494   }
495   for (Operation *op : opOrder) {
496     int64_t useStage = stages[op];
497     auto *newOp = rewriter.clone(*op, mapping);
498     SmallVector<OpOperand *> operands;
499     // Collect all the operands for the cloned op and its nested ops.
500     op->walk([&operands](Operation *nestedOp) {
501       for (OpOperand &operand : nestedOp->getOpOperands()) {
502         operands.push_back(&operand);
503       }
504     });
505     for (OpOperand *operand : operands) {
506       Operation *nestedNewOp = mapping.lookup(operand->getOwner());
507       // Special case for the induction variable uses. We replace it with a
508       // version incremented based on the stage where it is used.
509       if (operand->get() == forOp.getInductionVar()) {
510         rewriter.setInsertionPoint(newOp);
511 
512         // offset = (maxStage - stages[op]) * step
513         Type t = step.getType();
514         Value offset = rewriter.create<arith::MulIOp>(
515             forOp.getLoc(), step,
516             rewriter.create<arith::ConstantOp>(
517                 forOp.getLoc(),
518                 rewriter.getIntegerAttr(t, maxStage - stages[op])));
519         Value iv = rewriter.create<arith::AddIOp>(
520             forOp.getLoc(), newForOp.getInductionVar(), offset);
521         nestedNewOp->setOperand(operand->getOperandNumber(), iv);
522         rewriter.setInsertionPointAfter(newOp);
523         continue;
524       }
525       Value source = operand->get();
526       auto arg = dyn_cast<BlockArgument>(source);
527       if (arg && arg.getOwner() == forOp.getBody()) {
528         Value ret = forOp.getBody()->getTerminator()->getOperand(
529             arg.getArgNumber() - 1);
530         Operation *dep = ret.getDefiningOp();
531         if (!dep)
532           continue;
533         auto stageDep = stages.find(dep);
534         if (stageDep == stages.end() || stageDep->second == useStage)
535           continue;
536         // If the value is a loop carried value coming from stage N + 1 remap,
537         // it will become a direct use.
538         if (stageDep->second == useStage + 1) {
539           nestedNewOp->setOperand(operand->getOperandNumber(),
540                                   mapping.lookupOrDefault(ret));
541           continue;
542         }
543         source = ret;
544       }
545       // For operands defined in a previous stage we need to remap it to use
546       // the correct region argument. We look for the right version of the
547       // Value based on the stage where it is used.
548       Operation *def = source.getDefiningOp();
549       if (!def)
550         continue;
551       auto stageDef = stages.find(def);
552       if (stageDef == stages.end() || stageDef->second == useStage)
553         continue;
554       auto remap = loopArgMap.find(
555           std::make_pair(operand->get(), useStage - stageDef->second));
556       assert(remap != loopArgMap.end());
557       nestedNewOp->setOperand(operand->getOperandNumber(),
558                               newForOp.getRegionIterArgs()[remap->second]);
559     }
560 
561     if (predicates[useStage]) {
562       OpBuilder::InsertionGuard insertGuard(rewriter);
563       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
564       if (!newOp)
565         return failure();
566       // Remap the results to the new predicated one.
567       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
568         mapping.map(std::get<0>(values), std::get<1>(values));
569     }
570     if (annotateFn)
571       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
572   }
573 
574   // Collect the Values that need to be returned by the forOp. For each
575   // value we need to have `LastUseStage - DefStage` number of versions
576   // returned.
577   // We create a mapping between original values and the associated loop
578   // returned values that will be needed by the epilogue.
579   llvm::SmallVector<Value> yieldOperands;
580   for (OpOperand &yieldOperand :
581        forOp.getBody()->getTerminator()->getOpOperands()) {
582     Value source = mapping.lookupOrDefault(yieldOperand.get());
583     // When we don't peel the epilogue and the yield value is used outside the
584     // loop we need to make sure we return the version from numStages -
585     // defStage.
586     if (!peelEpilogue &&
587         !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
588       Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
589       if (def) {
590         auto defStage = stages.find(def);
591         if (defStage != stages.end() && defStage->second < maxStage) {
592           Value pred = predicates[defStage->second];
593           source = rewriter.create<arith::SelectOp>(
594               pred.getLoc(), pred, source,
595               newForOp.getBody()
596                   ->getArguments()[yieldOperand.getOperandNumber() + 1]);
597         }
598       }
599     }
600     yieldOperands.push_back(source);
601   }
602 
603   for (auto &it : crossStageValues) {
604     int64_t version = maxStage - it.second.lastUseStage + 1;
605     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
606     // add the original version to yield ops.
607     // If there is a live range spanning across more than 2 stages we need to
608     // add extra arg.
609     for (unsigned i = 1; i < numVersionReturned; i++) {
610       setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
611                       version++);
612       yieldOperands.push_back(
613           newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
614                                              newForOp.getNumInductionVars()]);
615     }
616     setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
617                     version++);
618     yieldOperands.push_back(mapping.lookupOrDefault(it.first));
619   }
620   // Map the yield operand to the forOp returned value.
621   for (const auto &retVal :
622        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
623     Operation *def = retVal.value().getDefiningOp();
624     assert(def && "Only support loop carried dependencies of distance of 1 or "
625                   "defined outside the loop");
626     auto defStage = stages.find(def);
627     if (defStage == stages.end()) {
628       for (unsigned int stage = 1; stage <= maxStage; stage++)
629         setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
630                         retVal.value(), stage);
631     } else if (defStage->second > 0) {
632       setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
633                       newForOp->getResult(retVal.index()),
634                       maxStage - defStage->second + 1);
635     }
636   }
637   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
638   return success();
639 }
640 
641 LogicalResult
642 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
643                                     llvm::SmallVector<Value> &returnValues) {
644   Location loc = forOp.getLoc();
645   Type t = lb.getType();
646 
647   // Emit different versions of the induction variable. They will be
648   // removed by dead code if not used.
649 
650   auto createConst = [&](int v) {
651     return rewriter.create<arith::ConstantOp>(loc,
652                                               rewriter.getIntegerAttr(t, v));
653   };
654 
655   // total_iterations = cdiv(range_diff, step);
656   // - range_diff = ub - lb
657   // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
658   Value zero = createConst(0);
659   Value one = createConst(1);
660   Value stepLessZero = rewriter.create<arith::CmpIOp>(
661       loc, arith::CmpIPredicate::slt, step, zero);
662   Value stepDecr =
663       rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
664 
665   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
666   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
667   Value rangeDecr =
668       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
669   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
670 
671   // If total_iters < max_stage, start the epilogue at zero to match the
672   // ramp-up in the prologue.
673   // start_iter = max(0, total_iters - max_stage)
674   Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
675                                                createConst(maxStage));
676   iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
677 
678   // Capture predicates for dynamic loops.
679   SmallVector<Value> predicates(maxStage + 1);
680 
681   for (int64_t i = 1; i <= maxStage; i++) {
682     // newLastIter = lb + step * iterI
683     Value newlastIter = rewriter.create<arith::AddIOp>(
684         loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
685 
686     setValueMapping(forOp.getInductionVar(), newlastIter, i);
687 
688     // increment to next iterI
689     iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
690 
691     if (dynamicLoop) {
692       // Disable stages when `i` is greater than total_iters.
693       // pred = total_iters >= i
694       predicates[i] = rewriter.create<arith::CmpIOp>(
695           loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
696     }
697   }
698 
699   // Emit `maxStage - 1` epilogue part that includes operations from stages
700   // [i; maxStage].
701   for (int64_t i = 1; i <= maxStage; i++) {
702     SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
703     for (Operation *op : opOrder) {
704       if (stages[op] < i)
705         continue;
706       unsigned currentVersion = maxStage - stages[op] + i;
707       unsigned nextVersion = currentVersion + 1;
708       Operation *newOp =
709           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
710             auto it = valueMapping.find(newOperand->get());
711             if (it != valueMapping.end()) {
712               Value replacement = it->second[currentVersion];
713               newOperand->set(replacement);
714             }
715           });
716       if (dynamicLoop) {
717         OpBuilder::InsertionGuard insertGuard(rewriter);
718         newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
719         if (!newOp)
720           return failure();
721       }
722       if (annotateFn)
723         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
724 
725       for (auto [opRes, newRes] :
726            llvm::zip(op->getResults(), newOp->getResults())) {
727         setValueMapping(opRes, newRes, currentVersion);
728         // If the value is a loop carried dependency update the loop argument
729         // mapping and keep track of the last version to replace the original
730         // forOp uses.
731         for (OpOperand &operand :
732              forOp.getBody()->getTerminator()->getOpOperands()) {
733           if (operand.get() != opRes)
734             continue;
735           // If the version is greater than maxStage it means it maps to the
736           // original forOp returned value.
737           unsigned ri = operand.getOperandNumber();
738           returnValues[ri] = newRes;
739           Value mapVal = forOp.getRegionIterArgs()[ri];
740           returnMap[ri] = std::make_pair(mapVal, currentVersion);
741           if (nextVersion <= maxStage)
742             setValueMapping(mapVal, newRes, nextVersion);
743         }
744       }
745     }
746     if (dynamicLoop) {
747       // Select return values from this stage (live outs) based on predication.
748       // If the stage is valid select the peeled value, else use previous stage
749       // value.
750       for (auto pair : llvm::enumerate(returnValues)) {
751         unsigned ri = pair.index();
752         auto [mapVal, currentVersion] = returnMap[ri];
753         if (mapVal) {
754           unsigned nextVersion = currentVersion + 1;
755           Value pred = predicates[currentVersion];
756           Value prevValue = valueMapping[mapVal][currentVersion];
757           auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
758                                                         prevValue);
759           returnValues[ri] = selOp;
760           if (nextVersion <= maxStage)
761             setValueMapping(mapVal, selOp, nextVersion);
762         }
763       }
764     }
765   }
766   return success();
767 }
768 
769 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
770   auto it = valueMapping.find(key);
771   // If the value is not in the map yet add a vector big enough to store all
772   // versions.
773   if (it == valueMapping.end())
774     it =
775         valueMapping
776             .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
777             .first;
778   it->second[idx] = el;
779 }
780 
781 } // namespace
782 
783 FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
784                                             const PipeliningOption &options,
785                                             bool *modifiedIR) {
786   if (modifiedIR)
787     *modifiedIR = false;
788   LoopPipelinerInternal pipeliner;
789   if (!pipeliner.initializeLoopInfo(forOp, options))
790     return failure();
791 
792   if (modifiedIR)
793     *modifiedIR = true;
794 
795   // 1. Emit prologue.
796   if (failed(pipeliner.emitPrologue(rewriter)))
797     return failure();
798 
799   // 2. Track values used across stages. When a value cross stages it will
800   // need to be passed as loop iteration arguments.
801   // We first collect the values that are used in a different stage than where
802   // they are defined.
803   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
804       crossStageValues = pipeliner.analyzeCrossStageValues();
805 
806   // Mapping between original loop values used cross stage and the block
807   // arguments associated after pipelining. A Value may map to several
808   // arguments if its liverange spans across more than 2 stages.
809   llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
810   // 3. Create the new kernel loop and return the block arguments mapping.
811   ForOp newForOp =
812       pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
813   // Create the kernel block, order ops based on user choice and remap
814   // operands.
815   if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
816                                     rewriter)))
817     return failure();
818 
819   llvm::SmallVector<Value> returnValues =
820       newForOp.getResults().take_front(forOp->getNumResults());
821   if (options.peelEpilogue) {
822     // 4. Emit the epilogue after the new forOp.
823     rewriter.setInsertionPointAfter(newForOp);
824     if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
825       return failure();
826   }
827   // 5. Erase the original loop and replace the uses with the epilogue output.
828   if (forOp->getNumResults() > 0)
829     rewriter.replaceOp(forOp, returnValues);
830   else
831     rewriter.eraseOp(forOp);
832 
833   return newForOp;
834 }
835 
836 void mlir::scf::populateSCFLoopPipeliningPatterns(
837     RewritePatternSet &patterns, const PipeliningOption &options) {
838   patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext());
839 }
840