xref: /llvm-project/mlir/lib/Dialect/SCF/IR/SCF.cpp (revision 9d8e634e85ca46fbec07733d3e69d34c0d7814ac)
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/FunctionInterfaces.h"
23 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
40   using DialectInlinerInterface::DialectInlinerInterface;
41   // We don't have any special restrictions on what can be inlined into
42   // destination regions (e.g. while/conditional bodies). Always allow it.
43   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44                        IRMapping &valueMapping) const final {
45     return true;
46   }
47   // Operations in the scf dialect are always legal to inline since they are
48   // pure.
49   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50     return true;
51   }
52   // Handle the given inlined terminator by replacing it with a new operation
53   // as necessary. Required when the region has only one block.
54   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55     auto retValOp = dyn_cast<scf::YieldOp>(op);
56     if (!retValOp)
57       return;
58 
59     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61     }
62   }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71   addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74       >();
75   addInterfaces<SCFInlinerInterface>();
76   declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77                             InParallelOp, ReduceReturnOp>();
78   declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79                             ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80                             ForallOp, InParallelOp, WhileOp, YieldOp>();
81   declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
85 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
86   builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93                                            StringRef errorMessage) {
94   Operation *terminatorOperation = nullptr;
95   if (!region.empty() && !region.front().empty()) {
96     terminatorOperation = &region.front().back();
97     if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98       return yield;
99   }
100   auto diag = op->emitOpError(errorMessage);
101   if (terminatorOperation)
102     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103   return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113                                 Region &region, ValueRange blockArgs = {}) {
114   assert(llvm::hasSingleElement(region) && "expected single-block region");
115   Block *block = &region.front();
116   Operation *terminator = block->getTerminator();
117   ValueRange results = terminator->getOperands();
118   rewriter.inlineBlockBefore(block, op, blockArgs);
119   rewriter.replaceOp(op, results);
120   rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 ///    block+
126 /// `}`
127 ///
128 /// Example:
129 ///   scf.execute_region -> i32 {
130 ///     %idx = load %rI[%i] : memref<128xi32>
131 ///     return %idx : i32
132 ///   }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135                                    OperationState &result) {
136   if (parser.parseOptionalArrowTypeList(result.types))
137     return failure();
138 
139   // Introduce the body region and parse it.
140   Region *body = result.addRegion();
141   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142       parser.parseOptionalAttrDict(result.attributes))
143     return failure();
144 
145   return success();
146 }
147 
148 void ExecuteRegionOp::print(OpAsmPrinter &p) {
149   p.printOptionalArrowTypeList(getResultTypes());
150 
151   p << ' ';
152   p.printRegion(getRegion(),
153                 /*printEntryBlockArgs=*/false,
154                 /*printBlockTerminators=*/true);
155 
156   p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160   if (getRegion().empty())
161     return emitOpError("region needs to have at least one block");
162   if (getRegion().front().getNumArguments() > 0)
163     return emitOpError("region cannot have any arguments");
164   return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 //     "test.foo"() : () -> ()
169 //      %v = scf.execute_region -> i64 {
170 //        %x = "test.val"() : () -> i64
171 //        scf.yield %x : i64
172 //      }
173 //      "test.bar"(%v) : (i64) -> ()
174 //
175 //  becomes
176 //
177 //     "test.foo"() : () -> ()
178 //     %x = "test.val"() : () -> i64
179 //     "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
182   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
183 
184   LogicalResult matchAndRewrite(ExecuteRegionOp op,
185                                 PatternRewriter &rewriter) const override {
186     if (!llvm::hasSingleElement(op.getRegion()))
187       return failure();
188     replaceOpWithRegion(rewriter, op, op.getRegion());
189     return success();
190   }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO: generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 //     "test.foo"() : () -> ()
197 //     %v = scf.execute_region -> i64 {
198 //       %c = "test.cmp"() : () -> i1
199 //       cf.cond_br %c, ^bb2, ^bb3
200 //     ^bb2:
201 //       %x = "test.val1"() : () -> i64
202 //       cf.br ^bb4(%x : i64)
203 //     ^bb3:
204 //       %y = "test.val2"() : () -> i64
205 //       cf.br ^bb4(%y : i64)
206 //     ^bb4(%z : i64):
207 //       scf.yield %z : i64
208 //     }
209 //     "test.bar"(%v) : (i64) -> ()
210 //   return
211 // }
212 //
213 //  becomes
214 //
215 // func @func_execute_region_elim() {
216 //    "test.foo"() : () -> ()
217 //    %c = "test.cmp"() : () -> i1
218 //    cf.cond_br %c, ^bb1, ^bb2
219 //  ^bb1:  // pred: ^bb0
220 //    %x = "test.val1"() : () -> i64
221 //    cf.br ^bb3(%x : i64)
222 //  ^bb2:  // pred: ^bb0
223 //    %y = "test.val2"() : () -> i64
224 //    cf.br ^bb3(%y : i64)
225 //  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
226 //    "test.bar"(%z) : (i64) -> ()
227 //    return
228 //  }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
231   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
232 
233   LogicalResult matchAndRewrite(ExecuteRegionOp op,
234                                 PatternRewriter &rewriter) const override {
235     if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236       return failure();
237 
238     Block *prevBlock = op->getBlock();
239     Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240     rewriter.setInsertionPointToEnd(prevBlock);
241 
242     rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244     for (Block &blk : op.getRegion()) {
245       if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246         rewriter.setInsertionPoint(yieldOp);
247         rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248                                       yieldOp.getResults());
249         rewriter.eraseOp(yieldOp);
250       }
251     }
252 
253     rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254     SmallVector<Value> blockArgs;
255 
256     for (auto res : op.getResults())
257       blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259     rewriter.replaceOp(op, blockArgs);
260     return success();
261   }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265                                                   MLIRContext *context) {
266   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
270     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
271   // If the predecessor is the ExecuteRegionOp, branch into the body.
272   if (point.isParent()) {
273     regions.push_back(RegionSuccessor(&getRegion()));
274     return;
275   }
276 
277   // Otherwise, the region branches back to the parent operation.
278   regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
285 MutableOperandRange
286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
287   assert((point.isParent() || point == getParentOp().getAfter()) &&
288          "condition op can only exit the loop or branch to the after"
289          "region");
290   // Pass all operands except the condition to the successor region.
291   return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
295     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
296   FoldAdaptor adaptor(operands, *this);
297 
298   WhileOp whileOp = getParentOp();
299 
300   // Condition can either lead to the after region or back to the parent op
301   // depending on whether the condition is true or not.
302   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303   if (!boolAttr || boolAttr.getValue())
304     regions.emplace_back(&whileOp.getAfter(),
305                          whileOp.getAfter().getArguments());
306   if (!boolAttr || !boolAttr.getValue())
307     regions.emplace_back(whileOp.getResults());
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315                   Value ub, Value step, ValueRange initArgs,
316                   BodyBuilderFn bodyBuilder) {
317   OpBuilder::InsertionGuard guard(builder);
318 
319   result.addOperands({lb, ub, step});
320   result.addOperands(initArgs);
321   for (Value v : initArgs)
322     result.addTypes(v.getType());
323   Type t = lb.getType();
324   Region *bodyRegion = result.addRegion();
325   Block *bodyBlock = builder.createBlock(bodyRegion);
326   bodyBlock->addArgument(t, result.location);
327   for (Value v : initArgs)
328     bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330   // Create the default terminator if the builder is not provided and if the
331   // iteration arguments are not provided. Otherwise, leave this to the caller
332   // because we don't know which values to return from the loop.
333   if (initArgs.empty() && !bodyBuilder) {
334     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335   } else if (bodyBuilder) {
336     OpBuilder::InsertionGuard guard(builder);
337     builder.setInsertionPointToStart(bodyBlock);
338     bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339                 bodyBlock->getArguments().drop_front());
340   }
341 }
342 
343 LogicalResult ForOp::verify() {
344   // Check that the number of init args and op results is the same.
345   if (getInitArgs().size() != getNumResults())
346     return emitOpError(
347         "mismatch in number of loop-carried values and defined values");
348 
349   return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353   // Check that the body defines a single block argument for the induction
354   // variable.
355   if (getInductionVar().getType() != getLowerBound().getType())
356     return emitOpError(
357         "expected induction variable to be same type as bounds and step");
358 
359   if (getNumRegionIterArgs() != getNumResults())
360     return emitOpError(
361         "mismatch in number of basic block args and defined values");
362 
363   auto initArgs = getInitArgs();
364   auto iterArgs = getRegionIterArgs();
365   auto opResults = getResults();
366   unsigned i = 0;
367   for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368     if (std::get<0>(e).getType() != std::get<2>(e).getType())
369       return emitOpError() << "types mismatch between " << i
370                            << "th iter operand and defined value";
371     if (std::get<1>(e).getType() != std::get<2>(e).getType())
372       return emitOpError() << "types mismatch between " << i
373                            << "th iter region arg and defined value";
374 
375     ++i;
376   }
377   return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381   return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
385   return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389   return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
393   return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if it can be
399 /// determined that the loop has a single iteration.
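///
/// As an illustrative sketch (not taken from the test suite; "test.compute"
/// is a placeholder op), a loop such as
///   %r = scf.for %iv = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
///     %v = "test.compute"(%iv, %a) : (index, f32) -> f32
///     scf.yield %v : f32
///   }
/// has a constant trip count of 1: the body is inlined into the parent block
/// with %iv replaced by %c0 and %a replaced by %init, and %r is replaced by
/// the yielded value %v.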
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401   std::optional<int64_t> tripCount =
402       constantTripCount(getLowerBound(), getUpperBound(), getStep());
403   if (!tripCount.has_value() || tripCount != 1)
404     return failure();
405 
406   // Replace all results with the yielded values.
407   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408   rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410   // Replace block arguments with lower bound (replacement for IV) and
411   // iter_args.
412   SmallVector<Value> bbArgReplacements;
413   bbArgReplacements.push_back(getLowerBound());
414   llvm::append_range(bbArgReplacements, getInitArgs());
415 
416   // Move the loop body operations to the loop's containing block.
417   rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418                              getOperation()->getIterator(), bbArgReplacements);
419 
420   // Erase the old terminator and the loop.
421   rewriter.eraseOp(yieldOp);
422   rewriter.eraseOp(*this);
423 
424   return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
431 static void printInitializationList(OpAsmPrinter &p,
432                                     Block::BlockArgListType blocksArgs,
433                                     ValueRange initializers,
434                                     StringRef prefix = "") {
435   assert(blocksArgs.size() == initializers.size() &&
436          "expected same length of arguments and initializers");
437   if (initializers.empty())
438     return;
439 
440   p << prefix << '(';
441   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442     p << std::get<0>(it) << " = " << std::get<1>(it);
443   });
444   p << ")";
445 }
446 
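/// The custom assembly form emitted below looks roughly as follows (a sketch
/// of the expected shape, not the authoritative grammar):
///   scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) {
///     ...
///     scf.yield %next : f32
///   }
/// The `iter_args` clause and the result arrow are omitted when the loop
/// carries no values, and a ` : type` is printed before the region when the
/// induction variable is not of index type.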
447 void ForOp::print(OpAsmPrinter &p) {
448   p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449     << getUpperBound() << " step " << getStep();
450 
451   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452   if (!getInitArgs().empty())
453     p << " -> (" << getInitArgs().getTypes() << ')';
454   p << ' ';
455   if (Type t = getInductionVar().getType(); !t.isIndex())
456     p << " : " << t << ' ';
457   p.printRegion(getRegion(),
458                 /*printEntryBlockArgs=*/false,
459                 /*printBlockTerminators=*/!getInitArgs().empty());
460   p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464   auto &builder = parser.getBuilder();
465   Type type;
466 
467   OpAsmParser::Argument inductionVariable;
468   OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470   // Parse the induction variable followed by '='.
471   if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472       // Parse loop bounds.
473       parser.parseOperand(lb) || parser.parseKeyword("to") ||
474       parser.parseOperand(ub) || parser.parseKeyword("step") ||
475       parser.parseOperand(step))
476     return failure();
477 
478   // Parse the optional initial iteration arguments.
479   SmallVector<OpAsmParser::Argument, 4> regionArgs;
480   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
481   regionArgs.push_back(inductionVariable);
482 
483   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484   if (hasIterArgs) {
485     // Parse assignment list and results type list.
486     if (parser.parseAssignmentList(regionArgs, operands) ||
487         parser.parseArrowTypeList(result.types))
488       return failure();
489   }
490 
491   if (regionArgs.size() != result.types.size() + 1)
492     return parser.emitError(
493         parser.getNameLoc(),
494         "mismatch in number of loop-carried values and defined values");
495 
496   // Parse optional type, else assume Index.
497   if (parser.parseOptionalColon())
498     type = builder.getIndexType();
499   else if (parser.parseType(type))
500     return failure();
501 
502   // Resolve input operands.
503   regionArgs.front().type = type;
504   if (parser.resolveOperand(lb, type, result.operands) ||
505       parser.resolveOperand(ub, type, result.operands) ||
506       parser.resolveOperand(step, type, result.operands))
507     return failure();
508   if (hasIterArgs) {
509     for (auto argOperandType :
510          llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
511       Type type = std::get<2>(argOperandType);
512       std::get<0>(argOperandType).type = type;
513       if (parser.resolveOperand(std::get<1>(argOperandType), type,
514                                 result.operands))
515         return failure();
516     }
517   }
518 
519   // Parse the body region.
520   Region *body = result.addRegion();
521   if (parser.parseRegion(*body, regionArgs))
522     return failure();
523 
524   ForOp::ensureTerminator(*body, builder, result.location);
525 
526   // Parse the optional attribute list.
527   if (parser.parseOptionalAttrDict(result.attributes))
528     return failure();
529 
530   return success();
531 }
532 
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536   return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540   return getInitArgsMutable();
541 }
542 
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545                                    ValueRange newInitOperands,
546                                    bool replaceInitOperandUsesInLoop,
547                                    const NewYieldValuesFn &newYieldValuesFn) {
548   // Create a new loop before the existing one, with the extra operands.
549   OpBuilder::InsertionGuard g(rewriter);
550   rewriter.setInsertionPoint(getOperation());
551   auto inits = llvm::to_vector(getInitArgs());
552   inits.append(newInitOperands.begin(), newInitOperands.end());
553   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554       getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555       [](OpBuilder &, Location, Value, ValueRange) {});
556   newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558   // Generate the new yield values and append them to the scf.yield operation.
559   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560   ArrayRef<BlockArgument> newIterArgs =
561       newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562   {
563     OpBuilder::InsertionGuard g(rewriter);
564     rewriter.setInsertionPoint(yieldOp);
565     SmallVector<Value> newYieldedValues =
566         newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567     assert(newInitOperands.size() == newYieldedValues.size() &&
568            "expected as many new yield values as new iter operands");
569     rewriter.modifyOpInPlace(yieldOp, [&]() {
570       yieldOp.getResultsMutable().append(newYieldedValues);
571     });
572   }
573 
574   // Move the loop body to the new op.
575   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576                        newLoop.getBody()->getArguments().take_front(
577                            getBody()->getNumArguments()));
578 
579   if (replaceInitOperandUsesInLoop) {
580     // Replace all uses of `newInitOperands` with the corresponding basic block
581     // arguments.
582     for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583       rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584                                  [&](OpOperand &use) {
585                                    Operation *user = use.getOwner();
586                                    return newLoop->isProperAncestor(user);
587                                  });
588     }
589   }
590 
591   // Replace the old loop.
592   rewriter.replaceOp(getOperation(),
593                      newLoop->getResults().take_front(getNumResults()));
594   return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
597 ForOp mlir::scf::getForInductionVarOwner(Value val) {
598   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599   if (!ivArg)
600     return ForOp();
601   assert(ivArg.getOwner() && "unlinked block argument");
602   auto *containingOp = ivArg.getOwner()->getParentOp();
603   return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607   return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
611                                 SmallVectorImpl<RegionSuccessor> &regions) {
612   // Both the operation itself and the region may be branching into the body or
613   // back into the operation itself. It is possible for the loop not to enter
614   // the body.
615   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616   regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624   for (auto [lb, ub, step] :
625        llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626     auto tripCount = constantTripCount(lb, ub, step);
627     if (!tripCount.has_value() || *tripCount != 1)
628       return failure();
629   }
630 
631   promote(rewriter, *this);
632   return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636   return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640   return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
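///
/// Roughly (an illustrative sketch; "test.produce" is a placeholder op), a
/// loop such as
///   %r = scf.forall (%i) in (%c1) shared_outs(%o = %t) -> (tensor<4xf32>) {
///     %s = "test.produce"() : () -> tensor<4xf32>
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %s into %o[0] [4] [1]
///           : tensor<4xf32> into tensor<4xf32>
///     }
///   }
/// is replaced by the inlined "test.produce" followed by a tensor.insert_slice
/// into %t, and %r is replaced by the result of that insert_slice.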
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645   OpBuilder::InsertionGuard g(rewriter);
646   scf::InParallelOp terminator = forallOp.getTerminator();
647 
648   // Replace block arguments with lower bounds (replacements for IVs) and
649   // outputs.
650   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651   bbArgReplacements.append(forallOp.getOutputs().begin(),
652                            forallOp.getOutputs().end());
653 
654   // Move the loop body operations to the loop's containing block.
655   rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656                              forallOp->getIterator(), bbArgReplacements);
657 
658   // Replace the terminator with tensor.insert_slice ops.
659   rewriter.setInsertionPointAfter(forallOp);
660   SmallVector<Value> results;
661   results.reserve(forallOp.getResults().size());
662   for (auto &yieldingOp : terminator.getYieldingOps()) {
663     auto parallelInsertSliceOp =
664         cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666     Value dst = parallelInsertSliceOp.getDest();
667     Value src = parallelInsertSliceOp.getSource();
668     if (llvm::isa<TensorType>(src.getType())) {
669       results.push_back(rewriter.create<tensor::InsertSliceOp>(
670           forallOp.getLoc(), dst.getType(), src, dst,
671           parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672           parallelInsertSliceOp.getStrides(),
673           parallelInsertSliceOp.getStaticOffsets(),
674           parallelInsertSliceOp.getStaticSizes(),
675           parallelInsertSliceOp.getStaticStrides()));
676     } else {
677       llvm_unreachable("unsupported terminator");
678     }
679   }
680   rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682   // Erase the old terminator and the loop.
683   rewriter.eraseOp(terminator);
684   rewriter.eraseOp(forallOp);
685 }
686 
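/// Example usage (a hypothetical sketch; `builder`, `loc`, `lbs`, `ubs`,
/// `steps` and `inits` are assumed to be provided by the caller):
///
///   scf::LoopNest nest = scf::buildLoopNest(
///       builder, loc, lbs, ubs, steps, inits,
///       [](OpBuilder &b, Location l, ValueRange ivs, ValueRange iterArgs)
///           -> scf::ValueVector {
///         // Emit the innermost body here; return the next loop-carried
///         // values, one per iter_arg.
///         return {iterArgs.begin(), iterArgs.end()};
///       });
///   // nest.loops holds the created scf.for ops (outermost first);
///   // nest.results holds the results of the outermost loop.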
687 LoopNest mlir::scf::buildLoopNest(
688     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689     ValueRange steps, ValueRange iterArgs,
690     function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
691         bodyBuilder) {
692   assert(lbs.size() == ubs.size() &&
693          "expected the same number of lower and upper bounds");
694   assert(lbs.size() == steps.size() &&
695          "expected the same number of lower bounds and steps");
696 
697   // If there are no bounds, call the body-building function and return early.
698   if (lbs.empty()) {
699     ValueVector results =
700         bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701                     : ValueVector();
702     assert(results.size() == iterArgs.size() &&
703            "loop nest body must return as many values as loop has iteration "
704            "arguments");
705     return LoopNest{{}, std::move(results)};
706   }
707 
708   // First, create the loop structure iteratively using the body-builder
709   // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710   OpBuilder::InsertionGuard guard(builder);
711   SmallVector<scf::ForOp, 4> loops;
712   SmallVector<Value, 4> ivs;
713   loops.reserve(lbs.size());
714   ivs.reserve(lbs.size());
715   ValueRange currentIterArgs = iterArgs;
716   Location currentLoc = loc;
717   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718     auto loop = builder.create<scf::ForOp>(
719         currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720         [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721             ValueRange args) {
722           ivs.push_back(iv);
723           // It is safe to store ValueRange args because it points to block
724           // arguments of a loop operation that we also own.
725           currentIterArgs = args;
726           currentLoc = nestedLoc;
727         });
728     // Set the builder to point to the body of the newly created loop. We don't
729     // do this in the callback because the builder is reset when the callback
730     // returns.
731     builder.setInsertionPointToStart(loop.getBody());
732     loops.push_back(loop);
733   }
734 
735   // For all loops but the innermost, yield the results of the nested loop.
736   for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737     builder.setInsertionPointToEnd(loops[i].getBody());
738     builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739   }
740 
741   // In the body of the innermost loop, call the body building function if any
742   // and yield its results.
743   builder.setInsertionPointToStart(loops.back().getBody());
744   ValueVector results = bodyBuilder
745                             ? bodyBuilder(builder, currentLoc, ivs,
746                                           loops.back().getRegionIterArgs())
747                             : ValueVector();
748   assert(results.size() == iterArgs.size() &&
749          "loop nest body must return as many values as loop has iteration "
750          "arguments");
751   builder.setInsertionPointToEnd(loops.back().getBody());
752   builder.create<scf::YieldOp>(loc, results);
753 
754   // Return the loops.
755   ValueVector nestResults;
756   llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757   return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
760 LoopNest mlir::scf::buildLoopNest(
761     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762     ValueRange steps,
763     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764   // Delegate to the main function by wrapping the body builder.
765   return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766                        [&bodyBuilder](OpBuilder &nestedBuilder,
767                                       Location nestedLoc, ValueRange ivs,
768                                       ValueRange) -> ValueVector {
769                          if (bodyBuilder)
770                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
771                          return {};
772                        });
773 }
774 
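/// Replaces the iter_arg of `forOp` addressed by `operand` with `replacement`,
/// which may have a different type. `castFn` is used to materialize casts on
/// entry into the body, before the terminator, and after the loop, so that the
/// body and all external users keep seeing values of the original type.
/// Returns the values that replace the results of the original loop.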
775 SmallVector<Value>
776 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
777                                       OpOperand &operand, Value replacement,
778                                       const ValueTypeCastFnTy &castFn) {
779   assert(operand.getOwner() == forOp);
780   Type oldType = operand.get().getType(), newType = replacement.getType();
781 
782   // 1. Create new iter operands; exactly one is replaced.
783   assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
784          "expected an iter OpOperand");
785   assert(operand.get().getType() != replacement.getType() &&
786          "Expected a different type");
787   SmallVector<Value> newIterOperands;
788   for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
789     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
790       newIterOperands.push_back(replacement);
791       continue;
792     }
793     newIterOperands.push_back(opOperand.get());
794   }
795 
796   // 2. Create the new forOp shell.
797   scf::ForOp newForOp = rewriter.create<scf::ForOp>(
798       forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799       forOp.getStep(), newIterOperands);
800   newForOp->setAttrs(forOp->getAttrs());
801   Block &newBlock = newForOp.getRegion().front();
802   SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
803                                              newBlock.getArguments().end());
804 
805   // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806   // corresponding to the `replacement` value.
807   OpBuilder::InsertionGuard g(rewriter);
808   rewriter.setInsertionPointToStart(&newBlock);
809   BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
810       &newForOp->getOpOperand(operand.getOperandNumber()));
811   Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812   newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
813 
814   // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815   Block &oldBlock = forOp.getRegion().front();
816   rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
817 
818   // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819   auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
820   rewriter.setInsertionPoint(clonedYieldOp);
821   unsigned yieldIdx =
822       newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
823   Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824                          clonedYieldOp.getOperand(yieldIdx));
825   SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
826   newYieldOperands[yieldIdx] = castOut;
827   rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828   rewriter.eraseOp(clonedYieldOp);
829 
830   // 6. Inject an outgoing cast op after the forOp.
831   rewriter.setInsertionPointAfter(newForOp);
832   SmallVector<Value> newResults = newForOp.getResults();
833   newResults[yieldIdx] =
834       castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
835 
836   return newResults;
837 }
838 
839 namespace {
840 // Fold away ForOp iter arguments when:
841 // 1) The op yields the iter arguments.
842 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
843 // 3) The iter arguments have no use and the corresponding (operation) results
844 // have no use.
845 //
846 // These arguments must be defined outside of
847 // the ForOp region and can just be forwarded after simplifying the op inits,
848 // yields and returns.
849 //
850 // The implementation uses `inlineBlockBefore` to steal the content of the
851 // original ForOp and avoid cloning.
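//
// For example (an illustrative sketch; "test.use" is a placeholder op), an
// iter argument that is yielded back unchanged:
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     "test.use"(%a) : (f32) -> ()
//     scf.yield %a : f32
//   }
// is folded away: the loop keeps no iter_args, the use inside the body is
// rewritten to use %init directly, and %r is replaced by %init.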
852 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
853   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
854 
855   LogicalResult matchAndRewrite(scf::ForOp forOp,
856                                 PatternRewriter &rewriter) const final {
857     bool canonicalize = false;
858 
859     // An internal flat vector of block transfer
860     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
861     // transformed block arguments. This plays the role of an
862     // IRMapping for the particular use case of calling into
863     // `inlineBlockBefore`.
864     int64_t numResults = forOp.getNumResults();
865     SmallVector<bool, 4> keepMask;
866     keepMask.reserve(numResults);
867     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
868         newResultValues;
869     newBlockTransferArgs.reserve(1 + numResults);
870     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
871     newIterArgs.reserve(forOp.getInitArgs().size());
872     newYieldValues.reserve(numResults);
873     newResultValues.reserve(numResults);
874     for (auto [init, arg, result, yielded] :
875          llvm::zip(forOp.getInitArgs(),       // iter from outside
876                    forOp.getRegionIterArgs(), // iter inside region
877                    forOp.getResults(),        // op results
878                    forOp.getYieldedValues()   // iter yield
879                    )) {
880       // Forwarded is `true` when:
881       // 1) The region `iter` argument is yielded.
882       // 2) The input corresponding to the region `iter` argument is yielded.
883       // 3) The region `iter` argument has no use, and the corresponding op
884       // result has no use.
885       bool forwarded = (arg == yielded) || (init == yielded) ||
886                        (arg.use_empty() && result.use_empty());
887       keepMask.push_back(!forwarded);
888       canonicalize |= forwarded;
889       if (forwarded) {
890         newBlockTransferArgs.push_back(init);
891         newResultValues.push_back(init);
892         continue;
893       }
894       newIterArgs.push_back(init);
895       newYieldValues.push_back(yielded);
896       newBlockTransferArgs.push_back(Value()); // placeholder with null value
897       newResultValues.push_back(Value());      // placeholder with null value
898     }
899 
900     if (!canonicalize)
901       return failure();
902 
903     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
904         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905         forOp.getStep(), newIterArgs);
906     newForOp->setAttrs(forOp->getAttrs());
907     Block &newBlock = newForOp.getRegion().front();
908 
909     // Replace the null placeholders with newly constructed values.
910     newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
911     for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
912          idx != e; ++idx) {
913       Value &blockTransferArg = newBlockTransferArgs[1 + idx];
914       Value &newResultVal = newResultValues[idx];
915       assert((blockTransferArg && newResultVal) ||
916              (!blockTransferArg && !newResultVal));
917       if (!blockTransferArg) {
918         blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
919         newResultVal = newForOp.getResult(collapsedIdx++);
920       }
921     }
922 
923     Block &oldBlock = forOp.getRegion().front();
924     assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
925            "unexpected argument size mismatch");
926 
927     // No results case: the scf::ForOp builder already created a zero
928     // result terminator. Merge before this terminator and just get rid of the
929     // original terminator that has been merged in.
930     if (newIterArgs.empty()) {
931       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
932       rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
933       rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
934       rewriter.replaceOp(forOp, newResultValues);
935       return success();
936     }
937 
938     // No terminator case: the new loop has no terminator yet; merge and rewrite the merged terminator.
939     auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
940       OpBuilder::InsertionGuard g(rewriter);
941       rewriter.setInsertionPoint(mergedTerminator);
942       SmallVector<Value, 4> filteredOperands;
943       filteredOperands.reserve(newResultValues.size());
944       for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
945         if (keepMask[idx])
946           filteredOperands.push_back(mergedTerminator.getOperand(idx));
947       rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
948                                     filteredOperands);
949     };
950 
951     rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
952     auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
953     cloneFilteredTerminator(mergedYieldOp);
954     rewriter.eraseOp(mergedYieldOp);
955     rewriter.replaceOp(forOp, newResultValues);
956     return success();
957   }
958 };
959 
960 /// Util function that tries to compute a constant diff between u and l.
961 /// Returns std::nullopt when the difference between the two values cannot be
962 /// determined statically.
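/// For example, if `u` is defined as `%u = arith.addi %l, %c4` with %c4 a
/// constant 4, the computed difference is 4.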
963 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
964   IntegerAttr clb, cub;
965   if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
966     llvm::APInt lbValue = clb.getValue();
967     llvm::APInt ubValue = cub.getValue();
968     return (ubValue - lbValue).getSExtValue();
969   }
970 
971   // Else a simple pattern match for x + c or c + x
972   llvm::APInt diff;
973   if (matchPattern(
974           u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
975       matchPattern(
976           u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
977     return diff.getSExtValue();
978   return std::nullopt;
979 }
980 
981 /// Rewriting pattern that erases loops that are known not to iterate, replaces
982 /// single-iteration loops with their bodies, and removes empty loops that
983 /// iterate at least once and only return values defined outside of the loop.
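///
/// For instance (a sketch with hypothetical constants), with %c0 = 0,
/// %c4 = 4 and %c8 = 8 the loop
///   %r = scf.for %i = %c0 to %c4 step %c8 iter_args(%a = %init) -> (f32) {...}
/// has a known trip count of 1 and is replaced by its inlined body, while a
/// loop whose lower and upper bounds are the same SSA value is simply replaced
/// by its init values.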
984 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
985   using OpRewritePattern<ForOp>::OpRewritePattern;
986 
987   LogicalResult matchAndRewrite(ForOp op,
988                                 PatternRewriter &rewriter) const override {
989     // If the upper bound is the same as the lower bound, the loop does not
990     // iterate, just remove it.
991     if (op.getLowerBound() == op.getUpperBound()) {
992       rewriter.replaceOp(op, op.getInitArgs());
993       return success();
994     }
995 
996     std::optional<int64_t> diff =
997         computeConstDiff(op.getLowerBound(), op.getUpperBound());
998     if (!diff)
999       return failure();
1000 
1001     // If the loop is known to have 0 iterations, remove it.
1002     if (*diff <= 0) {
1003       rewriter.replaceOp(op, op.getInitArgs());
1004       return success();
1005     }
1006 
1007     std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1008     if (!maybeStepValue)
1009       return failure();
1010 
1011     // If the loop is known to have 1 iteration, inline its body and remove the
1012     // loop.
1013     llvm::APInt stepValue = *maybeStepValue;
1014     if (stepValue.sge(*diff)) {
1015       SmallVector<Value, 4> blockArgs;
1016       blockArgs.reserve(op.getInitArgs().size() + 1);
1017       blockArgs.push_back(op.getLowerBound());
1018       llvm::append_range(blockArgs, op.getInitArgs());
1019       replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1020       return success();
1021     }
1022 
1023     // Now we are left with loops that have more than 1 iteration.
1024     Block &block = op.getRegion().front();
1025     if (!llvm::hasSingleElement(block))
1026       return failure();
1027     // If the loop is empty, iterates at least once, and only returns values
1028     // defined outside of the loop, remove it and replace it with yield values.
1029     if (llvm::any_of(op.getYieldedValues(),
1030                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1031       return failure();
1032     rewriter.replaceOp(op, op.getYieldedValues());
1033     return success();
1034   }
1035 };
1036 
1037 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1038 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1039 ///
1040 /// ```
1041 ///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1042 ///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1043 ///      -> (tensor<?x?xf32>) {
1044 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1045 ///     scf.yield %2 : tensor<?x?xf32>
1046 ///   }
1047 ///   use_of(%1)
1048 /// ```
1049 ///
1050 /// folds into:
1051 ///
1052 /// ```
1053 ///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1054 ///       -> (tensor<32x1024xf32>) {
1055 ///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1056 ///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1057 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1058 ///     scf.yield %4 : tensor<32x1024xf32>
1059 ///   }
1060 ///   %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1061 ///   use_of(%1)
1062 /// ```
1063 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1064   using OpRewritePattern<ForOp>::OpRewritePattern;
1065 
1066   LogicalResult matchAndRewrite(ForOp op,
1067                                 PatternRewriter &rewriter) const override {
1068     for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1069       OpOperand &iterOpOperand = std::get<0>(it);
1070       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1071       if (!incomingCast ||
1072           incomingCast.getSource().getType() == incomingCast.getType())
1073         continue;
1074       // Skip if the dest type of the cast does not preserve static information
1075       // present in the source type.
1076       if (!tensor::preservesStaticInformation(
1077               incomingCast.getDest().getType(),
1078               incomingCast.getSource().getType()))
1079         continue;
1080       if (!std::get<1>(it).hasOneUse())
1081         continue;
1082 
1083       // Create a new ForOp with that iter operand replaced.
1084       rewriter.replaceOp(
1085           op, replaceAndCastForOpIterArg(
1086                   rewriter, op, iterOpOperand, incomingCast.getSource(),
1087                   [](OpBuilder &b, Location loc, Type type, Value source) {
1088                     return b.create<tensor::CastOp>(loc, type, source);
1089                   }));
1090       return success();
1091     }
1092     return failure();
1093   }
1094 };
1095 
1096 } // namespace
1097 
1098 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1099                                         MLIRContext *context) {
1100   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1101       context);
1102 }
1103 
1104 std::optional<APInt> ForOp::getConstantStep() {
1105   IntegerAttr step;
1106   if (matchPattern(getStep(), m_Constant(&step)))
1107     return step.getValue();
1108   return {};
1109 }
1110 
1111 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1112   return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1113 }
1114 
1115 Speculation::Speculatability ForOp::getSpeculatability() {
1116   // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1117   // and End.
1118   if (auto constantStep = getConstantStep())
1119     if (*constantStep == 1)
1120       return Speculation::RecursivelySpeculatable;
1121 
1122   // For Step != 1, the loop may not terminate.  We can add more smarts here if
1123   // needed.
1124   return Speculation::NotSpeculatable;
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // ForallOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 LogicalResult ForallOp::verify() {
1132   unsigned numLoops = getRank();
1133   // Check number of outputs.
1134   if (getNumResults() != getOutputs().size())
1135     return emitOpError("produces ")
1136            << getNumResults() << " results, but has only "
1137            << getOutputs().size() << " outputs";
1138 
1139   // Check that the body defines block arguments for thread indices and outputs.
1140   auto *body = getBody();
1141   if (body->getNumArguments() != numLoops + getOutputs().size())
1142     return emitOpError("region expects ") << numLoops << " arguments";
1143   for (int64_t i = 0; i < numLoops; ++i)
1144     if (!body->getArgument(i).getType().isIndex())
1145       return emitOpError("expects ")
1146              << i << "-th block argument to be an index";
1147   for (unsigned i = 0; i < getOutputs().size(); ++i)
1148     if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1149       return emitOpError("type mismatch between ")
1150              << i << "-th output and corresponding block argument";
1151   if (getMapping().has_value() && !getMapping()->empty()) {
1152     if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1153       return emitOpError() << "mapping attribute size must match op rank";
1154     for (auto map : getMapping()->getValue()) {
1155       if (!isa<DeviceMappingAttrInterface>(map))
1156         return emitOpError()
1157                << getMappingAttrName() << " is not device mapping attribute";
1158     }
1159   }
1160 
1161   // Verify mixed static/dynamic control variables.
1162   Operation *op = getOperation();
1163   if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1164                                             getStaticLowerBound(),
1165                                             getDynamicLowerBound())))
1166     return failure();
1167   if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1168                                             getStaticUpperBound(),
1169                                             getDynamicUpperBound())))
1170     return failure();
1171   if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1172                                             getStaticStep(), getDynamicStep())))
1173     return failure();
1174 
1175   return success();
1176 }
1177 
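/// The printed form is roughly one of the following two shapes (a sketch, not
/// the authoritative grammar), each followed by an optional
/// `shared_outs(%o = %init) -> (types)` clause, the region, and the remaining
/// attributes:
///   scf.forall (%i, %j) in (%ub0, %ub1) ...
/// for normalized loops (zero lower bounds and unit steps), or
///   scf.forall (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) ...
/// otherwise.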
1178 void ForallOp::print(OpAsmPrinter &p) {
1179   Operation *op = getOperation();
1180   p << " (" << getInductionVars();
1181   if (isNormalized()) {
1182     p << ") in ";
1183     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1184                           /*valueTypes=*/{}, /*scalables=*/{},
1185                           OpAsmParser::Delimiter::Paren);
1186   } else {
1187     p << ") = ";
1188     printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1189                           /*valueTypes=*/{}, /*scalables=*/{},
1190                           OpAsmParser::Delimiter::Paren);
1191     p << " to ";
1192     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1193                           /*valueTypes=*/{}, /*scalables=*/{},
1194                           OpAsmParser::Delimiter::Paren);
1195     p << " step ";
1196     printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1197                           /*valueTypes=*/{}, /*scalables=*/{},
1198                           OpAsmParser::Delimiter::Paren);
1199   }
1200   printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1201   p << " ";
1202   if (!getRegionOutArgs().empty())
1203     p << "-> (" << getResultTypes() << ") ";
1204   p.printRegion(getRegion(),
1205                 /*printEntryBlockArgs=*/false,
1206                 /*printBlockTerminators=*/getNumResults() > 0);
1207   p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1208                                            getStaticLowerBoundAttrName(),
1209                                            getStaticUpperBoundAttrName(),
1210                                            getStaticStepAttrName()});
1211 }
1212 
1213 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1214   OpBuilder b(parser.getContext());
1215   auto indexType = b.getIndexType();
1216 
1217   // Parse an opening `(` followed by thread index variables followed by `)`
1218   // TODO: when we can refer to such "induction variable"-like handles from the
1219   // declarative assembly format, we can implement the parser as a custom hook.
1220   SmallVector<OpAsmParser::Argument, 4> ivs;
1221   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1222     return failure();
1223 
1224   DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1225   SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1226       dynamicSteps;
1227   if (succeeded(parser.parseOptionalKeyword("in"))) {
1228     // Parse upper bounds.
1229     if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1230                               /*valueTypes=*/nullptr,
1231                               OpAsmParser::Delimiter::Paren) ||
1232         parser.resolveOperands(dynamicUbs, indexType, result.operands))
1233       return failure();
1234 
1235     unsigned numLoops = ivs.size();
1236     staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1237     staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1238   } else {
1239     // Parse lower bounds.
1240     if (parser.parseEqual() ||
1241         parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1242                               /*valueTypes=*/nullptr,
1243                               OpAsmParser::Delimiter::Paren) ||
1244 
1245         parser.resolveOperands(dynamicLbs, indexType, result.operands))
1246       return failure();
1247 
1248     // Parse upper bounds.
1249     if (parser.parseKeyword("to") ||
1250         parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1251                               /*valueTypes=*/nullptr,
1252                               OpAsmParser::Delimiter::Paren) ||
1253         parser.resolveOperands(dynamicUbs, indexType, result.operands))
1254       return failure();
1255 
1256     // Parse step values.
1257     if (parser.parseKeyword("step") ||
1258         parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1259                               /*valueTypes=*/nullptr,
1260                               OpAsmParser::Delimiter::Paren) ||
1261         parser.resolveOperands(dynamicSteps, indexType, result.operands))
1262       return failure();
1263   }
1264 
1265   // Parse out operands and results.
1266   SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1267   SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1268   SMLoc outOperandsLoc = parser.getCurrentLocation();
1269   if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1270     if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1271         parser.parseOptionalArrowTypeList(result.types) ||
1272         parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1273                                result.operands))
1274       return failure();
1275   }
1276   if (outOperands.size() != result.types.size())
1277     return parser.emitError(outOperandsLoc,
1278                             "mismatch between out operands and types");
1279 
1280   // Parse region.
1281   SmallVector<OpAsmParser::Argument, 4> regionArgs;
1282   std::unique_ptr<Region> region = std::make_unique<Region>();
1283   for (auto &iv : ivs) {
1284     iv.type = b.getIndexType();
1285     regionArgs.push_back(iv);
1286   }
1287   for (const auto &it : llvm::enumerate(regionOutArgs)) {
1288     auto &out = it.value();
1289     out.type = result.types[it.index()];
1290     regionArgs.push_back(out);
1291   }
1292   if (parser.parseRegion(*region, regionArgs))
1293     return failure();
1294 
1295   // Ensure terminator and move region.
1296   ForallOp::ensureTerminator(*region, b, result.location);
1297   result.addRegion(std::move(region));
1298 
1299   // Parse the optional attribute list.
1300   if (parser.parseOptionalAttrDict(result.attributes))
1301     return failure();
1302 
1303   result.addAttribute("staticLowerBound", staticLbs);
1304   result.addAttribute("staticUpperBound", staticUbs);
1305   result.addAttribute("staticStep", staticSteps);
1306   result.addAttribute("operandSegmentSizes",
1307                       parser.getBuilder().getDenseI32ArrayAttr(
1308                           {static_cast<int32_t>(dynamicLbs.size()),
1309                            static_cast<int32_t>(dynamicUbs.size()),
1310                            static_cast<int32_t>(dynamicSteps.size()),
1311                            static_cast<int32_t>(outOperands.size())}));
1312   return success();
1313 }
1314 
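// A rough sketch of the two textual forms accepted by the parser above; the IR
// below is illustrative only (the normalized form elides lower bounds and
// steps, which default to 0 and 1):
//
//   scf.forall (%i, %j) in (8, %dim) {
//     ...
//   }
//
//   scf.forall (%i) = (0) to (%ub) step (2)
//       shared_outs(%o = %init) -> (tensor<?xf32>) {
//     ...
//   }
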
1315 // Builder that takes loop bounds.
1316 void ForallOp::build(
1317     mlir::OpBuilder &b, mlir::OperationState &result,
1318     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1319     ArrayRef<OpFoldResult> steps, ValueRange outputs,
1320     std::optional<ArrayAttr> mapping,
1321     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1322   SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1323   SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1324   dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1325   dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1326   dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1327 
1328   result.addOperands(dynamicLbs);
1329   result.addOperands(dynamicUbs);
1330   result.addOperands(dynamicSteps);
1331   result.addOperands(outputs);
1332   result.addTypes(TypeRange(outputs));
1333 
1334   result.addAttribute(getStaticLowerBoundAttrName(result.name),
1335                       b.getDenseI64ArrayAttr(staticLbs));
1336   result.addAttribute(getStaticUpperBoundAttrName(result.name),
1337                       b.getDenseI64ArrayAttr(staticUbs));
1338   result.addAttribute(getStaticStepAttrName(result.name),
1339                       b.getDenseI64ArrayAttr(staticSteps));
1340   result.addAttribute(
1341       "operandSegmentSizes",
1342       b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1343                               static_cast<int32_t>(dynamicUbs.size()),
1344                               static_cast<int32_t>(dynamicSteps.size()),
1345                               static_cast<int32_t>(outputs.size())}));
1346   if (mapping.has_value()) {
1347     result.addAttribute(ForallOp::getMappingAttrName(result.name),
1348                         mapping.value());
1349   }
1350 
1351   Region *bodyRegion = result.addRegion();
1352   OpBuilder::InsertionGuard g(b);
1353   b.createBlock(bodyRegion);
1354   Block &bodyBlock = bodyRegion->front();
1355 
1356   // Add block arguments for indices and outputs.
1357   bodyBlock.addArguments(
1358       SmallVector<Type>(lbs.size(), b.getIndexType()),
1359       SmallVector<Location>(lbs.size(), result.location));
1360   bodyBlock.addArguments(
1361       TypeRange(outputs),
1362       SmallVector<Location>(outputs.size(), result.location));
1363 
1364   b.setInsertionPointToStart(&bodyBlock);
1365   if (!bodyBuilderFn) {
1366     ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1367     return;
1368   }
1369   bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1370 }
1371 
1372 // Builder that takes only upper bounds; lower bounds default to 0 and steps to 1.
1373 void ForallOp::build(
1374     mlir::OpBuilder &b, mlir::OperationState &result,
1375     ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1376     std::optional<ArrayAttr> mapping,
1377     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1378   unsigned numLoops = ubs.size();
1379   SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1380   SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1381   build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1382 }
1383 
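// A minimal usage sketch of the convenience builder above; `b`, `loc`, `ubs`
// and `inits` are hypothetical names. Note that when a body builder is passed,
// the scf.forall.in_parallel terminator is not created automatically.
//
//   scf::ForallOp loop = b.create<scf::ForallOp>(
//       loc, getAsOpFoldResult(ubs), /*outputs=*/inits,
//       /*mapping=*/std::nullopt,
//       [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
//         // `args` holds the induction variables followed by the outputs.
//       });
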
1384 // Checks whether all lower bounds are zero and all steps are one.
1385 bool ForallOp::isNormalized() {
1386   auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1387     return llvm::all_of(results, [&](OpFoldResult ofr) {
1388       auto intValue = getConstantIntValue(ofr);
1389       return intValue.has_value() && intValue == val;
1390     });
1391   };
1392   return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1393 }
1394 
1395 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1396 // unaware of the fact that our terminator also needs a region to be
1397 // well-formed. We override it here to ensure that we do the right thing.
1398 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1399                                 Location loc) {
1400   OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1401       ForallOp>::ensureTerminator(region, builder, loc);
1402   auto terminator =
1403       llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1404   if (terminator.getRegion().empty())
1405     builder.createBlock(&terminator.getRegion());
1406 }
1407 
1408 InParallelOp ForallOp::getTerminator() {
1409   return cast<InParallelOp>(getBody()->getTerminator());
1410 }
1411 
1412 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1413   SmallVector<Operation *> storeOps;
1414   InParallelOp inParallelOp = getTerminator();
1415   for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1416     if (auto parallelInsertSliceOp =
1417             dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1418         parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1419       storeOps.push_back(parallelInsertSliceOp);
1420     }
1421   }
1422   return storeOps;
1423 }
1424 
1425 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1426   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1427 }
1428 
1429 // Get lower bounds as OpFoldResult.
1430 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1431   Builder b(getOperation()->getContext());
1432   return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1433 }
1434 
1435 // Get upper bounds as OpFoldResult.
1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1437   Builder b(getOperation()->getContext());
1438   return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1439 }
1440 
1441 // Get steps as OpFoldResult.
1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1443   Builder b(getOperation()->getContext());
1444   return getMixedValues(getStaticStep(), getDynamicStep(), b);
1445 }
1446 
1447 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1448   auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1449   if (!tidxArg)
1450     return ForallOp();
1451   assert(tidxArg.getOwner() && "unlinked block argument");
1452   auto *containingOp = tidxArg.getOwner()->getParentOp();
1453   return dyn_cast<ForallOp>(containingOp);
1454 }
1455 
1456 namespace {
1457 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
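/// A schematic example:
///
///   %0 = scf.forall ... shared_outs(%arg = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
/// -->
///   %d = tensor.dim %t, %c0 : tensor<?xf32>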
1458 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1459   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1460 
1461   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1462                                 PatternRewriter &rewriter) const final {
1463     auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1464     if (!forallOp)
1465       return failure();
1466     Value sharedOut =
1467         forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1468             ->get();
1469     rewriter.modifyOpInPlace(
1470         dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1471     return success();
1472   }
1473 };
1474 
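/// Fold dynamic lower/upper bounds and steps that are actually constants into
/// the corresponding static attributes. A schematic example:
///
///   %c4 = arith.constant 4 : index
///   scf.forall (%i) in (%c4) { ... }
/// -->
///   scf.forall (%i) in (4) { ... }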
1475 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1476 public:
1477   using OpRewritePattern<ForallOp>::OpRewritePattern;
1478 
1479   LogicalResult matchAndRewrite(ForallOp op,
1480                                 PatternRewriter &rewriter) const override {
1481     SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1482     SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1483     SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1484     if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1485         failed(foldDynamicIndexList(mixedUpperBound)) &&
1486         failed(foldDynamicIndexList(mixedStep)))
1487       return failure();
1488 
1489     rewriter.modifyOpInPlace(op, [&]() {
1490       SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1491       SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1492       dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1493                                  staticLowerBound);
1494       op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1495       op.setStaticLowerBound(staticLowerBound);
1496 
1497       dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1498                                  staticUpperBound);
1499       op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1500       op.setStaticUpperBound(staticUpperBound);
1501 
1502       dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1503       op.getDynamicStepMutable().assign(dynamicStep);
1504       op.setStaticStep(staticStep);
1505 
1506       op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1507                   rewriter.getDenseI32ArrayAttr(
1508                       {static_cast<int32_t>(dynamicLowerBound.size()),
1509                        static_cast<int32_t>(dynamicUpperBound.size()),
1510                        static_cast<int32_t>(dynamicStep.size()),
1511                        static_cast<int32_t>(op.getNumResults())}));
1512     });
1513     return success();
1514   }
1515 };
1516 
1517 /// The following canonicalization pattern folds the iter arguments of
1518 /// scf.forall op if :-
1519 /// 1. The corresponding result has zero uses.
1520 /// 2. The iter argument is NOT being modified within the loop body (i.e. it
1521 ///    has no store op with the iter argument as its destination).
1522 ///
1523 /// Example of first case :-
1524 ///  INPUT:
1525 ///   %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1526 ///            {
1527 ///                ...
1528 ///                <SOME USE OF %arg0>
1529 ///                <SOME USE OF %arg1>
1530 ///                <SOME USE OF %arg2>
1531 ///                ...
1532 ///                scf.forall.in_parallel {
1533 ///                    <STORE OP WITH DESTINATION %arg1>
1534 ///                    <STORE OP WITH DESTINATION %arg0>
1535 ///                    <STORE OP WITH DESTINATION %arg2>
1536 ///                }
1537 ///             }
1538 ///   return %res#1
1539 ///
1540 ///  OUTPUT:
1541 ///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
1542 ///            {
1543 ///                ...
1544 ///                <SOME USE OF %a>
1545 ///                <SOME USE OF %new_arg0>
1546 ///                <SOME USE OF %c>
1547 ///                ...
1548 ///                scf.forall.in_parallel {
1549 ///                    <STORE OP WITH DESTINATION %new_arg0>
1550 ///                }
1551 ///             }
1552 ///   return %res
1553 ///
1554 /// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
1555 ///          scf.forall are replaced by their corresponding operands.
1556 ///       2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1557 ///          of the scf.forall besides those within the scf.forall.in_parallel
1558 ///          terminator, this canonicalization remains valid. For more details,
1559 ///          please refer to:
1560 ///          https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1561 ///       3. TODO(avarma): Generalize it for other store ops. Currently it
1562 ///          handles tensor.parallel_insert_slice ops only.
1563 ///
1564 /// Example of second case :-
1565 ///  INPUT:
1566 ///   %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1567 ///            {
1568 ///                ...
1569 ///                <SOME USE OF %arg0>
1570 ///                <SOME USE OF %arg1>
1571 ///                ...
1572 ///                scf.forall.in_parallel {
1573 ///                    <STORE OP WITH DESTINATION %arg1>
1574 ///                }
1575 ///             }
1576 ///   return %res#0, %res#1
1577 ///
1578 ///  OUTPUT:
1579 ///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
1580 ///            {
1581 ///                ...
1582 ///                <SOME USE OF %a>
1583 ///                <SOME USE OF %new_arg0>
1584 ///                ...
1585 ///                scf.forall.in_parallel {
1586 ///                    <STORE OP WITH DESTINATION %new_arg0>
1587 ///                }
1588 ///             }
1589 ///   return %a, %res
1590 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1591   using OpRewritePattern<ForallOp>::OpRewritePattern;
1592 
1593   LogicalResult matchAndRewrite(ForallOp forallOp,
1594                                 PatternRewriter &rewriter) const final {
1595     // Step 1: For a given i-th result of scf.forall, check the following :-
1596     //         a. If it has any use.
1597     //         b. If the corresponding iter argument is being modified within
1598     //            the loop, i.e. has at least one store op with the iter arg as
1599     //            its destination operand. For this we use
1600     //            ForallOp::getCombiningOps(iter_arg).
1601     //
1602     //         Based on the check we maintain the following :-
1603     //         a. `resultToDelete` - i-th result of scf.forall that'll be
1604     //            deleted.
1605     //         b. `resultToReplace` - i-th result of the old scf.forall
1606     //            whose uses will be replaced by the new scf.forall.
1607     //         c. `newOuts` - the shared_outs' operand of the new scf.forall
1608     //            corresponding to the i-th result with at least one use.
1609     SetVector<OpResult> resultToDelete;
1610     SmallVector<Value> resultToReplace;
1611     SmallVector<Value> newOuts;
1612     for (OpResult result : forallOp.getResults()) {
1613       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1614       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1615       if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1616         resultToDelete.insert(result);
1617       } else {
1618         resultToReplace.push_back(result);
1619         newOuts.push_back(opOperand->get());
1620       }
1621     }
1622 
1623     // Return early if all results of scf.forall have at least one use and are
1624     // being modified within the loop.
1625     if (resultToDelete.empty())
1626       return failure();
1627 
1628     // Step 2: For the i-th result, do the following :-
1629     //         a. Fetch the corresponding BlockArgument.
1630     //         b. Look for store ops (currently tensor.parallel_insert_slice)
1631     //            with the BlockArgument as its destination operand.
1632     //         c. Remove the operations fetched in b.
1633     for (OpResult result : resultToDelete) {
1634       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1635       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1636       SmallVector<Operation *> combiningOps =
1637           forallOp.getCombiningOps(blockArg);
1638       for (Operation *combiningOp : combiningOps)
1639         rewriter.eraseOp(combiningOp);
1640     }
1641 
1642     // Step 3. Create a new scf.forall op with the new shared_outs' operands
1643     //         fetched earlier.
1644     auto newForallOp = rewriter.create<scf::ForallOp>(
1645         forallOp.getLoc(), forallOp.getMixedLowerBound(),
1646         forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1647         forallOp.getMapping(),
1648         /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1649 
1650     // Step 4. Merge the block of the old scf.forall into the newly created
1651     //         scf.forall using the new set of arguments.
1652     Block *loopBody = forallOp.getBody();
1653     Block *newLoopBody = newForallOp.getBody();
1654     ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1655     // Form initial new bbArg list with just the control operands of the new
1656     // scf.forall op.
1657     SmallVector<Value> newBlockArgs =
1658         llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1659                             [](BlockArgument b) -> Value { return b; });
1660     Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1661     unsigned index = 0;
1662     // Take the new corresponding bbArg if the old bbArg was used as a
1663     // destination in the in_parallel op. For all other bbArgs, use the
1664     // corresponding init_arg from the old scf.forall op.
1665     for (OpResult result : forallOp.getResults()) {
1666       if (resultToDelete.count(result)) {
1667         newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1668       } else {
1669         newBlockArgs.push_back(newSharedOutsArgs[index++]);
1670       }
1671     }
1672     rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1673 
1674     // Step 5. Replace the uses of the results of the old scf.forall with those
1675     //         of the new scf.forall.
1676     for (auto &&[oldResult, newResult] :
1677          llvm::zip(resultToReplace, newForallOp->getResults()))
1678       rewriter.replaceAllUsesWith(oldResult, newResult);
1679 
1680     // Step 6. Replace the uses of those results that either have no use or are
1681     //         not being modified within the loop with the corresponding
1682     //         OpOperand.
1683     for (OpResult oldResult : resultToDelete)
1684       rewriter.replaceAllUsesWith(oldResult,
1685                                   forallOp.getTiedOpOperand(oldResult)->get());
1686     return success();
1687   }
1688 };
1689 
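/// Fold away scf.forall loop dimensions that perform zero or a single
/// iteration: if any dimension has a zero trip count, the loop is replaced by
/// its outputs; single-iteration dimensions are dropped and their induction
/// variables are replaced by the corresponding lower bound. The pattern does
/// not apply when the loop carries a device mapping. Schematically:
///
///   scf.forall (%i, %j) in (1, %n) { ... use %i ... }
/// -->
///   scf.forall (%j) in (%n) { ... use %c0 ... }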
1690 struct ForallOpSingleOrZeroIterationDimsFolder
1691     : public OpRewritePattern<ForallOp> {
1692   using OpRewritePattern<ForallOp>::OpRewritePattern;
1693 
1694   LogicalResult matchAndRewrite(ForallOp op,
1695                                 PatternRewriter &rewriter) const override {
1696     // Do not fold dimensions if they are mapped to processing units.
1697     if (op.getMapping().has_value() && !op.getMapping()->empty())
1698       return failure();
1699     Location loc = op.getLoc();
1700 
1701     // Compute new loop bounds that omit all single-iteration loop dimensions.
1702     SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1703         newMixedSteps;
1704     IRMapping mapping;
1705     for (auto [lb, ub, step, iv] :
1706          llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1707                    op.getMixedStep(), op.getInductionVars())) {
1708       auto numIterations = constantTripCount(lb, ub, step);
1709       if (numIterations.has_value()) {
1710         // Remove the loop if it performs zero iterations.
1711         if (*numIterations == 0) {
1712           rewriter.replaceOp(op, op.getOutputs());
1713           return success();
1714         }
1715         // Replace the loop induction variable by the lower bound if the loop
1716         // performs a single iteration. Otherwise, copy the loop bounds.
1717         if (*numIterations == 1) {
1718           mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1719           continue;
1720         }
1721       }
1722       newMixedLowerBounds.push_back(lb);
1723       newMixedUpperBounds.push_back(ub);
1724       newMixedSteps.push_back(step);
1725     }
1726 
1727     // All of the loop dimensions perform a single iteration. Inline loop body.
1728     if (newMixedLowerBounds.empty()) {
1729       promote(rewriter, op);
1730       return success();
1731     }
1732 
1733     // Exit if none of the loop dimensions perform zero or a single iteration.
1734     if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1735       return rewriter.notifyMatchFailure(
1736           op, "no dimensions have 0 or 1 iterations");
1737     }
1738 
1739     // Replace the loop by a lower-dimensional loop.
1740     ForallOp newOp = rewriter.create<ForallOp>(
1741         loc, newMixedLowerBounds, newMixedUpperBounds, newMixedSteps,
1742         op.getOutputs(), std::nullopt, nullptr);
1744     newOp.getBodyRegion().getBlocks().clear();
1745     // The new loop needs to keep all attributes from the old one, except for
1746     // "operandSegmentSizes" and static loop bound attributes which capture
1747     // the outdated information of the old iteration domain.
1748     SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1749                                         newOp.getStaticLowerBoundAttrName(),
1750                                         newOp.getStaticUpperBoundAttrName(),
1751                                         newOp.getStaticStepAttrName()};
1752     for (const auto &namedAttr : op->getAttrs()) {
1753       if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1754         continue;
1755       rewriter.modifyOpInPlace(newOp, [&]() {
1756         newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1757       });
1758     }
1759     rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1760                                newOp.getRegion().begin(), mapping);
1761     rewriter.replaceOp(op, newOp.getResults());
1762     return success();
1763   }
1764 };
1765 
1766 /// Replace induction variables of single-iteration loop dimensions with their lower bound.
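/// The loop structure is left unchanged; only uses of the induction variable
/// are rewritten. A schematic example:
///
///   scf.forall (%i, %j) = (2, 0) to (3, %n) step (1, 1) { ... use %i ... }
/// -->
///   scf.forall (%i, %j) = (2, 0) to (3, %n) step (1, 1) { ... use %c2 ... }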
1767 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1768   using OpRewritePattern<ForallOp>::OpRewritePattern;
1769 
1770   LogicalResult matchAndRewrite(ForallOp op,
1771                                 PatternRewriter &rewriter) const override {
1772     Location loc = op.getLoc();
1773     bool changed = false;
1774     for (auto [lb, ub, step, iv] :
1775          llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1776                    op.getMixedStep(), op.getInductionVars())) {
1777       if (iv.use_empty())
1778         continue;
1779       auto numIterations = constantTripCount(lb, ub, step);
1780       if (!numIterations.has_value() || numIterations.value() != 1) {
1781         continue;
1782       }
1783       rewriter.replaceAllUsesWith(
1784           iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1785       changed = true;
1786     }
1787     return success(changed);
1788   }
1789 };
1790 
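/// Fold a tensor.cast of a shared_outs operand into the scf.forall op when
/// doing so makes the loop result type more static, and cast the corresponding
/// result back to the original type afterwards. Schematically:
///
///   %cast = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
/// -->
///   %r0 = scf.forall ... shared_outs(%o = %t) -> (tensor<4xf32>) { ... }
///   %r = tensor.cast %r0 : tensor<4xf32> to tensor<?xf32>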
1791 struct FoldTensorCastOfOutputIntoForallOp
1792     : public OpRewritePattern<scf::ForallOp> {
1793   using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1794 
1795   struct TypeCast {
1796     Type srcType;
1797     Type dstType;
1798   };
1799 
1800   LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1801                                 PatternRewriter &rewriter) const final {
1802     llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1803     llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1804     for (auto en : llvm::enumerate(newOutputTensors)) {
1805       auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1806       if (!castOp)
1807         continue;
1808 
1809       // Only casts that preserve static information, i.e. that will make the
1810       // loop result type "more" static than before, will be folded.
1811       if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1812                                               castOp.getSource().getType())) {
1813         continue;
1814       }
1815 
1816       tensorCastProducers[en.index()] =
1817           TypeCast{castOp.getSource().getType(), castOp.getType()};
1818       newOutputTensors[en.index()] = castOp.getSource();
1819     }
1820 
1821     if (tensorCastProducers.empty())
1822       return failure();
1823 
1824     // Create new loop.
1825     Location loc = forallOp.getLoc();
1826     auto newForallOp = rewriter.create<ForallOp>(
1827         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1828         forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1829         [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1830           auto castBlockArgs =
1831               llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1832           for (auto [index, cast] : tensorCastProducers) {
1833             Value &oldTypeBBArg = castBlockArgs[index];
1834             oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1835                 nestedLoc, cast.dstType, oldTypeBBArg);
1836           }
1837 
1838           // Move old body into new parallel loop.
1839           SmallVector<Value> ivsBlockArgs =
1840               llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1841           ivsBlockArgs.append(castBlockArgs);
1842           rewriter.mergeBlocks(forallOp.getBody(),
1843                                bbArgs.front().getParentBlock(), ivsBlockArgs);
1844         });
1845 
1846     // After `mergeBlocks`, the destinations in the terminator were mapped to
1847     // the old-typed tensor.cast results of the output bbArgs. The destinations
1848     // have to be updated to point to the output bbArgs directly.
1849     auto terminator = newForallOp.getTerminator();
1850     for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1851              terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1852       auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1853       insertSliceOp.getDestMutable().assign(outputBlockArg);
1854     }
1855 
1856     // Cast results back to the original types.
1857     rewriter.setInsertionPointAfter(newForallOp);
1858     SmallVector<Value> castResults = newForallOp.getResults();
1859     for (auto &item : tensorCastProducers) {
1860       Value &oldTypeResult = castResults[item.first];
1861       oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1862                                                       oldTypeResult);
1863     }
1864     rewriter.replaceOp(forallOp, castResults);
1865     return success();
1866   }
1867 };
1868 
1869 } // namespace
1870 
1871 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1872                                            MLIRContext *context) {
1873   results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1874               ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1875               ForallOpSingleOrZeroIterationDimsFolder,
1876               ForallOpReplaceConstantInductionVar>(context);
1877 }
1878 
1879 /// Given a branch point, i.e. the parent operation or one of its regions,
1880 /// return the successor regions. These are the regions that may be selected
1881 /// during the flow of control.
1884 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1885                                    SmallVectorImpl<RegionSuccessor> &regions) {
1886   // Both the operation itself and the region may be branching into the body
1887   // or back into the operation itself. It is possible for the loop not to
1888   // enter the body.
1889   regions.push_back(RegionSuccessor(&getRegion()));
1890   regions.push_back(RegionSuccessor());
1891 }
1892 
1893 //===----------------------------------------------------------------------===//
1894 // InParallelOp
1895 //===----------------------------------------------------------------------===//
1896 
1897 // Build an InParallelOp with an empty body block.
1898 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1899   OpBuilder::InsertionGuard g(b);
1900   Region *bodyRegion = result.addRegion();
1901   b.createBlock(bodyRegion);
1902 }
1903 
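// The verifier below only accepts bodies of the following shape (schematic;
// offsets, sizes and strides elided):
//
//   scf.forall (%i) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
//     ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %v into %o[...] [...] [...]
//           : ... into tensor<?xf32>
//     }
//   }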
1904 LogicalResult InParallelOp::verify() {
1905   scf::ForallOp forallOp =
1906       dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1907   if (!forallOp)
1908     return this->emitOpError("expected forall op parent");
1909 
1910   // TODO: InParallelOpInterface.
1911   for (Operation &op : getRegion().front().getOperations()) {
1912     if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1913       return this->emitOpError("expected only ")
1914              << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1915     }
1916 
1917     // Verify that inserts are into out block arguments.
1918     Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1919     ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1920     if (!llvm::is_contained(regionOutArgs, dest))
1921       return op.emitOpError("may only insert into an output block argument");
1922   }
1923   return success();
1924 }
1925 
1926 void InParallelOp::print(OpAsmPrinter &p) {
1927   p << " ";
1928   p.printRegion(getRegion(),
1929                 /*printEntryBlockArgs=*/false,
1930                 /*printBlockTerminators=*/false);
1931   p.printOptionalAttrDict(getOperation()->getAttrs());
1932 }
1933 
1934 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1935   auto &builder = parser.getBuilder();
1936 
1937   SmallVector<OpAsmParser::Argument, 8> regionOperands;
1938   std::unique_ptr<Region> region = std::make_unique<Region>();
1939   if (parser.parseRegion(*region, regionOperands))
1940     return failure();
1941 
1942   if (region->empty())
1943     OpBuilder(builder.getContext()).createBlock(region.get());
1944   result.addRegion(std::move(region));
1945 
1946   // Parse the optional attribute list.
1947   if (parser.parseOptionalAttrDict(result.attributes))
1948     return failure();
1949   return success();
1950 }
1951 
1952 OpResult InParallelOp::getParentResult(int64_t idx) {
1953   return getOperation()->getParentOp()->getResult(idx);
1954 }
1955 
1956 SmallVector<BlockArgument> InParallelOp::getDests() {
1957   return llvm::to_vector<4>(
1958       llvm::map_range(getYieldingOps(), [](Operation &op) {
1959         // Add new ops here as needed.
1960         auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1961         return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1962       }));
1963 }
1964 
1965 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1966   return getRegion().front().getOperations();
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // IfOp
1971 //===----------------------------------------------------------------------===//
1972 
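// Two operations are in mutually exclusive branches if they live in different
// regions (then vs. else) of a common enclosing scf.if. A schematic example:
//
//   scf.if %cond {
//     "test.op_a"() : () -> ()   // `a`
//   } else {
//     "test.op_b"() : () -> ()   // `b`
//   }
//
// Here `a` and `b` can never both execute, so the function returns true.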
1973 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1974   assert(a && "expected non-empty operation");
1975   assert(b && "expected non-empty operation");
1976 
1977   IfOp ifOp = a->getParentOfType<IfOp>();
1978   while (ifOp) {
1979     // Check if b is inside ifOp. (We already know that a is.)
1980     if (ifOp->isProperAncestor(b))
1981       // b is contained in ifOp. a and b are in mutually exclusive branches if
1982       // they are in different blocks of ifOp.
1983       return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1984              static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1985     // Check next enclosing IfOp.
1986     ifOp = ifOp->getParentOfType<IfOp>();
1987   }
1988 
1989   // Could not find a common IfOp among a's and b's ancestors.
1990   return false;
1991 }
1992 
1993 LogicalResult
1994 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1995                        IfOp::Adaptor adaptor,
1996                        SmallVectorImpl<Type> &inferredReturnTypes) {
1997   if (adaptor.getRegions().empty())
1998     return failure();
1999   Region *r = &adaptor.getThenRegion();
2000   if (r->empty())
2001     return failure();
2002   Block &b = r->front();
2003   if (b.empty())
2004     return failure();
2005   auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2006   if (!yieldOp)
2007     return failure();
2008   TypeRange types = yieldOp.getOperandTypes();
2009   inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2010                              types.end());
2011   return success();
2012 }
2013 
2014 void IfOp::build(OpBuilder &builder, OperationState &result,
2015                  TypeRange resultTypes, Value cond) {
2016   return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2017                /*addElseBlock=*/false);
2018 }
2019 
2020 void IfOp::build(OpBuilder &builder, OperationState &result,
2021                  TypeRange resultTypes, Value cond, bool addThenBlock,
2022                  bool addElseBlock) {
2023   assert((!addElseBlock || addThenBlock) &&
2024          "must not create else block w/o then block");
2025   result.addTypes(resultTypes);
2026   result.addOperands(cond);
2027 
2028   // Add regions and blocks.
2029   OpBuilder::InsertionGuard guard(builder);
2030   Region *thenRegion = result.addRegion();
2031   if (addThenBlock)
2032     builder.createBlock(thenRegion);
2033   Region *elseRegion = result.addRegion();
2034   if (addElseBlock)
2035     builder.createBlock(elseRegion);
2036 }
2037 
2038 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2039                  bool withElseRegion) {
2040   build(builder, result, TypeRange{}, cond, withElseRegion);
2041 }
2042 
2043 void IfOp::build(OpBuilder &builder, OperationState &result,
2044                  TypeRange resultTypes, Value cond, bool withElseRegion) {
2045   result.addTypes(resultTypes);
2046   result.addOperands(cond);
2047 
2048   // Build then region.
2049   OpBuilder::InsertionGuard guard(builder);
2050   Region *thenRegion = result.addRegion();
2051   builder.createBlock(thenRegion);
2052   if (resultTypes.empty())
2053     IfOp::ensureTerminator(*thenRegion, builder, result.location);
2054 
2055   // Build else region.
2056   Region *elseRegion = result.addRegion();
2057   if (withElseRegion) {
2058     builder.createBlock(elseRegion);
2059     if (resultTypes.empty())
2060       IfOp::ensureTerminator(*elseRegion, builder, result.location);
2061   }
2062 }
2063 
2064 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2065                  function_ref<void(OpBuilder &, Location)> thenBuilder,
2066                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2067   assert(thenBuilder && "the builder callback for 'then' must be present");
2068   result.addOperands(cond);
2069 
2070   // Build then region.
2071   OpBuilder::InsertionGuard guard(builder);
2072   Region *thenRegion = result.addRegion();
2073   builder.createBlock(thenRegion);
2074   thenBuilder(builder, result.location);
2075 
2076   // Build else region.
2077   Region *elseRegion = result.addRegion();
2078   if (elseBuilder) {
2079     builder.createBlock(elseRegion);
2080     elseBuilder(builder, result.location);
2081   }
2082 
2083   // Infer result types.
2084   SmallVector<Type> inferredReturnTypes;
2085   MLIRContext *ctx = builder.getContext();
2086   auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2087   if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2088                                  /*properties=*/nullptr, result.regions,
2089                                  inferredReturnTypes))) {
2090     result.addTypes(inferredReturnTypes);
2091   }
2092 }
2093 
2094 LogicalResult IfOp::verify() {
2095   if (getNumResults() != 0 && getElseRegion().empty())
2096     return emitOpError("must have an else block if defining values");
2097   return success();
2098 }
2099 
2100 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2101   // Create the regions for 'then'.
2102   result.regions.reserve(2);
2103   Region *thenRegion = result.addRegion();
2104   Region *elseRegion = result.addRegion();
2105 
2106   auto &builder = parser.getBuilder();
2107   OpAsmParser::UnresolvedOperand cond;
2108   Type i1Type = builder.getIntegerType(1);
2109   if (parser.parseOperand(cond) ||
2110       parser.resolveOperand(cond, i1Type, result.operands))
2111     return failure();
2112   // Parse optional results type list.
2113   if (parser.parseOptionalArrowTypeList(result.types))
2114     return failure();
2115   // Parse the 'then' region.
2116   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2117     return failure();
2118   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2119 
2120   // If we find an 'else' keyword then parse the 'else' region.
2121   if (!parser.parseOptionalKeyword("else")) {
2122     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2123       return failure();
2124     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2125   }
2126 
2127   // Parse the optional attribute list.
2128   if (parser.parseOptionalAttrDict(result.attributes))
2129     return failure();
2130   return success();
2131 }
2132 
2133 void IfOp::print(OpAsmPrinter &p) {
2134   bool printBlockTerminators = false;
2135 
2136   p << " " << getCondition();
2137   if (!getResults().empty()) {
2138     p << " -> (" << getResultTypes() << ")";
2139     // Print yield explicitly if the op defines values.
2140     printBlockTerminators = true;
2141   }
2142   p << ' ';
2143   p.printRegion(getThenRegion(),
2144                 /*printEntryBlockArgs=*/false,
2145                 /*printBlockTerminators=*/printBlockTerminators);
2146 
2147   // Print the 'else' region if it exists and has a block.
2148   auto &elseRegion = getElseRegion();
2149   if (!elseRegion.empty()) {
2150     p << " else ";
2151     p.printRegion(elseRegion,
2152                   /*printEntryBlockArgs=*/false,
2153                   /*printBlockTerminators=*/printBlockTerminators);
2154   }
2155 
2156   p.printOptionalAttrDict((*this)->getAttrs());
2157 }
2158 
2159 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2160                                SmallVectorImpl<RegionSuccessor> &regions) {
2161   // The `then` and the `else` region branch back to the parent operation.
2162   if (!point.isParent()) {
2163     regions.push_back(RegionSuccessor(getResults()));
2164     return;
2165   }
2166 
2167   regions.push_back(RegionSuccessor(&getThenRegion()));
2168 
2169   // Don't consider the else region if it is empty.
2170   Region *elseRegion = &this->getElseRegion();
2171   if (elseRegion->empty())
2172     regions.push_back(RegionSuccessor());
2173   else
2174     regions.push_back(RegionSuccessor(elseRegion));
2175 }
2176 
2177 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2178                                     SmallVectorImpl<RegionSuccessor> &regions) {
2179   FoldAdaptor adaptor(operands, *this);
2180   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2181   if (!boolAttr || boolAttr.getValue())
2182     regions.emplace_back(&getThenRegion());
2183 
2184   // If the else region is empty, execution continues after the parent op.
2185   if (!boolAttr || !boolAttr.getValue()) {
2186     if (!getElseRegion().empty())
2187       regions.emplace_back(&getElseRegion());
2188     else
2189       regions.emplace_back(getResults());
2190   }
2191 }
2192 
2193 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2194                          SmallVectorImpl<OpFoldResult> &results) {
2195   // if (!c) then A() else B() -> if c then B() else A()
2196   if (getElseRegion().empty())
2197     return failure();
2198 
2199   arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2200   if (!xorStmt)
2201     return failure();
2202 
2203   if (!matchPattern(xorStmt.getRhs(), m_One()))
2204     return failure();
2205 
2206   getConditionMutable().assign(xorStmt.getLhs());
2207   Block *thenBlock = &getThenRegion().front();
2208   // It would be nicer to use iplist::swap, but that has no implemented
2209   // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2210   getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2211                                      getElseRegion().getBlocks());
2212   getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2213                                      getThenRegion().getBlocks(), thenBlock);
2214   return success();
2215 }
2216 
2217 void IfOp::getRegionInvocationBounds(
2218     ArrayRef<Attribute> operands,
2219     SmallVectorImpl<InvocationBounds> &invocationBounds) {
2220   if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2221     // If the condition is known, then one region is known to be executed once
2222     // and the other zero times.
2223     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2224     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2225   } else {
2226     // Non-constant condition. Each region may be executed 0 or 1 times.
2227     invocationBounds.assign(2, {0, 1});
2228   }
2229 }
2230 
2231 namespace {
2232 // Pattern to remove unused IfOp results.
2233 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2234   using OpRewritePattern<IfOp>::OpRewritePattern;
2235 
2236   void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2237                     PatternRewriter &rewriter) const {
2238     // Move all operations to the destination block.
2239     rewriter.mergeBlocks(source, dest);
2240     // Replace the yield op by one that returns only the used values.
2241     auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2242     SmallVector<Value, 4> usedOperands;
2243     llvm::transform(usedResults, std::back_inserter(usedOperands),
2244                     [&](OpResult result) {
2245                       return yieldOp.getOperand(result.getResultNumber());
2246                     });
2247     rewriter.modifyOpInPlace(yieldOp,
2248                              [&]() { yieldOp->setOperands(usedOperands); });
2249   }
2250 
2251   LogicalResult matchAndRewrite(IfOp op,
2252                                 PatternRewriter &rewriter) const override {
2253     // Compute the list of used results.
2254     SmallVector<OpResult, 4> usedResults;
2255     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2256                   [](OpResult result) { return !result.use_empty(); });
2257 
2258     // Replace the operation if only a subset of its results have uses.
2259     if (usedResults.size() == op.getNumResults())
2260       return failure();
2261 
2262     // Compute the result types of the replacement operation.
2263     SmallVector<Type, 4> newTypes;
2264     llvm::transform(usedResults, std::back_inserter(newTypes),
2265                     [](OpResult result) { return result.getType(); });
2266 
2267     // Create a replacement operation with empty then and else regions.
2268     auto newOp =
2269         rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2270     rewriter.createBlock(&newOp.getThenRegion());
2271     rewriter.createBlock(&newOp.getElseRegion());
2272 
2273     // Move the bodies and replace the terminators (note there is a then and
2274     // an else region since the operation returns results).
2275     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2276     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2277 
2278     // Replace the operation by the new one.
2279     SmallVector<Value, 4> repResults(op.getNumResults());
2280     for (const auto &en : llvm::enumerate(usedResults))
2281       repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2282     rewriter.replaceOp(op, repResults);
2283     return success();
2284   }
2285 };
2286 
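/// Inline the region selected by a statically known condition and erase the
/// scf.if. Schematically:
///
///   %true = arith.constant true
///   scf.if %true {
///     someCode();...
///   }
/// -->
///   someCode();...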
2287 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2288   using OpRewritePattern<IfOp>::OpRewritePattern;
2289 
2290   LogicalResult matchAndRewrite(IfOp op,
2291                                 PatternRewriter &rewriter) const override {
2292     BoolAttr condition;
2293     if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2294       return failure();
2295 
2296     if (condition.getValue())
2297       replaceOpWithRegion(rewriter, op, op.getThenRegion());
2298     else if (!op.getElseRegion().empty())
2299       replaceOpWithRegion(rewriter, op, op.getElseRegion());
2300     else
2301       rewriter.eraseOp(op);
2302 
2303     return success();
2304   }
2305 };
2306 
2307 /// Hoist any yielded results whose operands are defined outside
2308 /// the if, to a select instruction.
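/// A schematic example, where %a and %b are defined outside the scf.if (a
/// now-trivial scf.if without results may remain):
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// -->
///   %r = arith.select %cond, %a, %b : i32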
2309 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2310   using OpRewritePattern<IfOp>::OpRewritePattern;
2311 
2312   LogicalResult matchAndRewrite(IfOp op,
2313                                 PatternRewriter &rewriter) const override {
2314     if (op->getNumResults() == 0)
2315       return failure();
2316 
2317     auto cond = op.getCondition();
2318     auto thenYieldArgs = op.thenYield().getOperands();
2319     auto elseYieldArgs = op.elseYield().getOperands();
2320 
2321     SmallVector<Type> nonHoistable;
2322     for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2323       if (&op.getThenRegion() == trueVal.getParentRegion() ||
2324           &op.getElseRegion() == falseVal.getParentRegion())
2325         nonHoistable.push_back(trueVal.getType());
2326     }
2327     // Early exit if there aren't any yielded values we can
2328     // hoist outside the if.
2329     if (nonHoistable.size() == op->getNumResults())
2330       return failure();
2331 
2332     IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2333                                              /*withElseRegion=*/false);
2334     if (replacement.thenBlock())
2335       rewriter.eraseBlock(replacement.thenBlock());
2336     replacement.getThenRegion().takeBody(op.getThenRegion());
2337     replacement.getElseRegion().takeBody(op.getElseRegion());
2338 
2339     SmallVector<Value> results(op->getNumResults());
2340     assert(thenYieldArgs.size() == results.size());
2341     assert(elseYieldArgs.size() == results.size());
2342 
2343     SmallVector<Value> trueYields;
2344     SmallVector<Value> falseYields;
2345     rewriter.setInsertionPoint(replacement);
2346     for (const auto &it :
2347          llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2348       Value trueVal = std::get<0>(it.value());
2349       Value falseVal = std::get<1>(it.value());
2350       if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2351           &replacement.getElseRegion() == falseVal.getParentRegion()) {
2352         results[it.index()] = replacement.getResult(trueYields.size());
2353         trueYields.push_back(trueVal);
2354         falseYields.push_back(falseVal);
2355       } else if (trueVal == falseVal)
2356         results[it.index()] = trueVal;
2357       else
2358         results[it.index()] = rewriter.create<arith::SelectOp>(
2359             op.getLoc(), cond, trueVal, falseVal);
2360     }
2361 
2362     rewriter.setInsertionPointToEnd(replacement.thenBlock());
2363     rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2364 
2365     rewriter.setInsertionPointToEnd(replacement.elseBlock());
2366     rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2367 
2368     rewriter.replaceOp(op, results);
2369     return success();
2370   }
2371 };
2372 
2373 /// Allow the true region of an if to assume the condition is true
2374 /// and vice versa. For example:
2375 ///
2376 ///   scf.if %cmp {
2377 ///      print(%cmp)
2378 ///   }
2379 ///
2380 ///  becomes
2381 ///
2382 ///   scf.if %cmp {
2383 ///      print(true)
2384 ///   }
2385 ///
2386 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2387   using OpRewritePattern<IfOp>::OpRewritePattern;
2388 
2389   LogicalResult matchAndRewrite(IfOp op,
2390                                 PatternRewriter &rewriter) const override {
2391     // Early exit if the condition is constant since replacing a constant
2392     // in the body with another constant isn't a simplification.
2393     if (matchPattern(op.getCondition(), m_Constant()))
2394       return failure();
2395 
2396     bool changed = false;
2397     mlir::Type i1Ty = rewriter.getI1Type();
2398 
2399     // These variables serve to prevent creating duplicate constants
2400     // and hold constant true or false values.
2401     Value constantTrue = nullptr;
2402     Value constantFalse = nullptr;
2403 
2404     for (OpOperand &use :
2405          llvm::make_early_inc_range(op.getCondition().getUses())) {
2406       if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2407         changed = true;
2408 
2409         if (!constantTrue)
2410           constantTrue = rewriter.create<arith::ConstantOp>(
2411               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2412 
2413         rewriter.modifyOpInPlace(use.getOwner(),
2414                                  [&]() { use.set(constantTrue); });
2415       } else if (op.getElseRegion().isAncestor(
2416                      use.getOwner()->getParentRegion())) {
2417         changed = true;
2418 
2419         if (!constantFalse)
2420           constantFalse = rewriter.create<arith::ConstantOp>(
2421               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2422 
2423         rewriter.modifyOpInPlace(use.getOwner(),
2424                                  [&]() { use.set(constantFalse); });
2425       }
2426     }
2427 
2428     return success(changed);
2429   }
2430 };
2431 
2432 /// Remove any statements from an if that are equivalent to the condition
2433 /// or its negation. For example:
2434 ///
2435 ///    %res:2 = scf.if %cmp {
2436 ///       yield something(), true
2437 ///    } else {
2438 ///       yield something2(), false
2439 ///    }
2440 ///    print(%res#1)
2441 ///
2442 ///  becomes
2443 ///    %res = scf.if %cmp {
2444 ///       yield something()
2445 ///    } else {
2446 ///       yield something2()
2447 ///    }
2448 ///    print(%cmp)
2449 ///
2450 /// Additionally if both branches yield the same value, replace all uses
2451 /// of the result with the yielded value.
2452 ///
2453 ///    %res:2 = scf.if %cmp {
2454 ///       yield something(), %arg1
2455 ///    } else {
2456 ///       yield something2(), %arg1
2457 ///    }
2458 ///    print(%res#1)
2459 ///
2460 ///  becomes
2461 ///    %res = scf.if %cmp {
2462 ///       yield something()
2463 ///    } else {
2464 ///       yield something2()
2465 ///    }
2466 ///    print(%arg1)
2467 ///
2468 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2469   using OpRewritePattern<IfOp>::OpRewritePattern;
2470 
2471   LogicalResult matchAndRewrite(IfOp op,
2472                                 PatternRewriter &rewriter) const override {
2473     // Early exit if there are no results that could be replaced.
2474     if (op.getNumResults() == 0)
2475       return failure();
2476 
2477     auto trueYield =
2478         cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2479     auto falseYield =
2480         cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2481 
2482     rewriter.setInsertionPoint(op->getBlock(),
2483                                op.getOperation()->getIterator());
2484     bool changed = false;
2485     Type i1Ty = rewriter.getI1Type();
2486     for (auto [trueResult, falseResult, opResult] :
2487          llvm::zip(trueYield.getResults(), falseYield.getResults(),
2488                    op.getResults())) {
2489       if (trueResult == falseResult) {
2490         if (!opResult.use_empty()) {
2491           opResult.replaceAllUsesWith(trueResult);
2492           changed = true;
2493         }
2494         continue;
2495       }
2496 
2497       BoolAttr trueYield, falseYield;
2498       if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2499           !matchPattern(falseResult, m_Constant(&falseYield)))
2500         continue;
2501 
2502       bool trueVal = trueYield.getValue();
2503       bool falseVal = falseYield.getValue();
2504       if (!trueVal && falseVal) {
2505         if (!opResult.use_empty()) {
2506           Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2507           Value notCond = rewriter.create<arith::XOrIOp>(
2508               op.getLoc(), op.getCondition(),
2509               constDialect
2510                   ->materializeConstant(rewriter,
2511                                         rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2512                                         op.getLoc())
2513                   ->getResult(0));
2514           opResult.replaceAllUsesWith(notCond);
2515           changed = true;
2516         }
2517       }
2518       if (trueVal && !falseVal) {
2519         if (!opResult.use_empty()) {
2520           opResult.replaceAllUsesWith(op.getCondition());
2521           changed = true;
2522         }
2523       }
2524     }
2525     return success(changed);
2526   }
2527 };
2528 
2529 /// Merge any consecutive scf.if's with the same condition.
2530 ///
2531 ///    scf.if %cond {
2532 ///       firstCodeTrue();...
2533 ///    } else {
2534 ///       firstCodeFalse();...
2535 ///    }
2536 ///    %res = scf.if %cond {
2537 ///       secondCodeTrue();...
2538 ///    } else {
2539 ///       secondCodeFalse();...
2540 ///    }
2541 ///
2542 ///  becomes
2543 ///    %res = scf.if %cmp {
2544 ///       firstCodeTrue();...
2545 ///       secondCodeTrue();...
2546 ///    } else {
2547 ///       firstCodeFalse();...
2548 ///       secondCodeFalse();...
2549 ///    }
2550 struct CombineIfs : public OpRewritePattern<IfOp> {
2551   using OpRewritePattern<IfOp>::OpRewritePattern;
2552 
2553   LogicalResult matchAndRewrite(IfOp nextIf,
2554                                 PatternRewriter &rewriter) const override {
2555     Block *parent = nextIf->getBlock();
2556     if (nextIf == &parent->front())
2557       return failure();
2558 
2559     auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2560     if (!prevIf)
2561       return failure();
2562 
2563     // Determine the logical then/else blocks when prevIf's
2564     // condition is used. Null means the block does not exist
2565     // in that case (e.g. empty else). If neither of these
2566     // is set, the two conditions cannot be compared.
2567     Block *nextThen = nullptr;
2568     Block *nextElse = nullptr;
2569     if (nextIf.getCondition() == prevIf.getCondition()) {
2570       nextThen = nextIf.thenBlock();
2571       if (!nextIf.getElseRegion().empty())
2572         nextElse = nextIf.elseBlock();
2573     }
2574     if (arith::XOrIOp notv =
2575             nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2576       if (notv.getLhs() == prevIf.getCondition() &&
2577           matchPattern(notv.getRhs(), m_One())) {
2578         nextElse = nextIf.thenBlock();
2579         if (!nextIf.getElseRegion().empty())
2580           nextThen = nextIf.elseBlock();
2581       }
2582     }
2583     if (arith::XOrIOp notv =
2584             prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2585       if (notv.getLhs() == nextIf.getCondition() &&
2586           matchPattern(notv.getRhs(), m_One())) {
2587         nextElse = nextIf.thenBlock();
2588         if (!nextIf.getElseRegion().empty())
2589           nextThen = nextIf.elseBlock();
2590       }
2591     }
2592 
2593     if (!nextThen && !nextElse)
2594       return failure();
2595 
2596     SmallVector<Value> prevElseYielded;
2597     if (!prevIf.getElseRegion().empty())
2598       prevElseYielded = prevIf.elseYield().getOperands();
2599     // Replace all uses of the results of prevIf within nextIf with the
2600     // corresponding yielded values.
2601     for (auto it : llvm::zip(prevIf.getResults(),
2602                              prevIf.thenYield().getOperands(), prevElseYielded))
2603       for (OpOperand &use :
2604            llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2605         if (nextThen && nextThen->getParent()->isAncestor(
2606                             use.getOwner()->getParentRegion())) {
2607           rewriter.startOpModification(use.getOwner());
2608           use.set(std::get<1>(it));
2609           rewriter.finalizeOpModification(use.getOwner());
2610         } else if (nextElse && nextElse->getParent()->isAncestor(
2611                                    use.getOwner()->getParentRegion())) {
2612           rewriter.startOpModification(use.getOwner());
2613           use.set(std::get<2>(it));
2614           rewriter.finalizeOpModification(use.getOwner());
2615         }
2616       }
2617 
2618     SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2619     llvm::append_range(mergedTypes, nextIf.getResultTypes());
2620 
2621     IfOp combinedIf = rewriter.create<IfOp>(
2622         nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2623     rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2624 
2625     rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2626                                 combinedIf.getThenRegion(),
2627                                 combinedIf.getThenRegion().begin());
2628 
2629     if (nextThen) {
2630       YieldOp thenYield = combinedIf.thenYield();
2631       YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2632       rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2633       rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2634 
2635       SmallVector<Value> mergedYields(thenYield.getOperands());
2636       llvm::append_range(mergedYields, thenYield2.getOperands());
2637       rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2638       rewriter.eraseOp(thenYield);
2639       rewriter.eraseOp(thenYield2);
2640     }
2641 
2642     rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2643                                 combinedIf.getElseRegion(),
2644                                 combinedIf.getElseRegion().begin());
2645 
2646     if (nextElse) {
2647       if (combinedIf.getElseRegion().empty()) {
2648         rewriter.inlineRegionBefore(*nextElse->getParent(),
2649                                     combinedIf.getElseRegion(),
2650                                     combinedIf.getElseRegion().begin());
2651       } else {
2652         YieldOp elseYield = combinedIf.elseYield();
2653         YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2654         rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2655 
2656         rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2657 
2658         SmallVector<Value> mergedElseYields(elseYield.getOperands());
2659         llvm::append_range(mergedElseYields, elseYield2.getOperands());
2660 
2661         rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2662         rewriter.eraseOp(elseYield);
2663         rewriter.eraseOp(elseYield2);
2664       }
2665     }
2666 
2667     SmallVector<Value> prevValues;
2668     SmallVector<Value> nextValues;
2669     for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2670       if (pair.index() < prevIf.getNumResults())
2671         prevValues.push_back(pair.value());
2672       else
2673         nextValues.push_back(pair.value());
2674     }
2675     rewriter.replaceOp(prevIf, prevValues);
2676     rewriter.replaceOp(nextIf, nextValues);
2677     return success();
2678   }
2679 };
2680 
2681 /// Pattern to remove an empty else branch.
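/// An illustrative sketch:
///
///    scf.if %cond {
///      "test.op"() : () -> ()
///    } else {
///    }
///
///  becomes
///
///    scf.if %cond {
///      "test.op"() : () -> ()
///    }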
2682 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2683   using OpRewritePattern<IfOp>::OpRewritePattern;
2684 
2685   LogicalResult matchAndRewrite(IfOp ifOp,
2686                                 PatternRewriter &rewriter) const override {
2687     // Cannot remove else region when there are operation results.
2688     if (ifOp.getNumResults())
2689       return failure();
2690     Block *elseBlock = ifOp.elseBlock();
2691     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2692       return failure();
2693     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2694     rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2695                                 newIfOp.getThenRegion().begin());
2696     rewriter.eraseOp(ifOp);
2697     return success();
2698   }
2699 };
2700 
2701 /// Convert nested `if`s into `arith.andi` + single `if`.
2702 ///
2703 ///    scf.if %arg0 {
2704 ///      scf.if %arg1 {
2705 ///        ...
2706 ///        scf.yield
2707 ///      }
2708 ///      scf.yield
2709 ///    }
2710 ///  becomes
2711 ///
2712 ///    %0 = arith.andi %arg0, %arg1
2713 ///    scf.if %0 {
2714 ///      ...
2715 ///      scf.yield
2716 ///    }
2717 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2718   using OpRewritePattern<IfOp>::OpRewritePattern;
2719 
2720   LogicalResult matchAndRewrite(IfOp op,
2721                                 PatternRewriter &rewriter) const override {
2722     auto nestedOps = op.thenBlock()->without_terminator();
2723     // Nested `if` must be the only op in block.
2724     if (!llvm::hasSingleElement(nestedOps))
2725       return failure();
2726 
2727     // If there is an else block, it may only contain the yield terminator.
2728     if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2729       return failure();
2730 
2731     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2732     if (!nestedIf)
2733       return failure();
2734 
2735     if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2736       return failure();
2737 
2738     SmallVector<Value> thenYield(op.thenYield().getOperands());
2739     SmallVector<Value> elseYield;
2740     if (op.elseBlock())
2741       llvm::append_range(elseYield, op.elseYield().getOperands());
2742 
2743     // A list of indices for which we should upgrade the value yielded
2744     // in the else to a select.
2745     SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2746 
2747     // If the outer scf.if yields a value produced by the inner scf.if,
2748     // only permit combining if the value yielded when the condition
2749     // is false in the outer scf.if is the same value yielded when the
2750     // inner scf.if condition is false.
2751     // Note that indexing into elseYield cannot go out of bounds: it has the
2752     // same length as thenYield because both yields belong to the same scf.if,
2753     // which must have an else region whenever it produces results.
2754     for (const auto &tup : llvm::enumerate(thenYield)) {
2755       if (tup.value().getDefiningOp() == nestedIf) {
2756         auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2757         if (nestedIf.elseYield().getOperand(nestedIdx) !=
2758             elseYield[tup.index()]) {
2759           return failure();
2760         }
2761         // If the correctness test passes, we will yield the
2762         // corresponding value from the inner scf.if.
2763         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2764         continue;
2765       }
2766 
2767       // Otherwise, we need to ensure the else block of the combined
2768       // condition still returns the same value when the outer condition is
2769       // true and the inner condition is false. This can be accomplished if
2770       // the then value is defined outside the outer scf.if and we replace the
2771       // value with a select that considers just the outer condition. Since
2772       // the else region contains just the yield, its yielded value is
2773       // defined outside the scf.if, by definition.
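      //
      // An illustrative sketch: if the outer `scf.if %c0` yields a value %x
      // defined above it, and its else branch yields %y, then after combining
      // the conditions that result is recreated as
      //   %r = arith.select %c0, %x, %y
      // so the merged op still returns %y whenever %c0 is false.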
2774 
2775       // If the then value is defined within the scf.if, bail.
2776       if (tup.value().getParentRegion() == &op.getThenRegion()) {
2777         return failure();
2778       }
2779       elseYieldsToUpgradeToSelect.push_back(tup.index());
2780     }
2781 
2782     Location loc = op.getLoc();
2783     Value newCondition = rewriter.create<arith::AndIOp>(
2784         loc, op.getCondition(), nestedIf.getCondition());
2785     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2786     Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2787 
2788     SmallVector<Value> results;
2789     llvm::append_range(results, newIf.getResults());
2790     rewriter.setInsertionPoint(newIf);
2791 
2792     for (auto idx : elseYieldsToUpgradeToSelect)
2793       results[idx] = rewriter.create<arith::SelectOp>(
2794           op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2795 
2796     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2797     rewriter.setInsertionPointToEnd(newIf.thenBlock());
2798     rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2799     if (!elseYield.empty()) {
2800       rewriter.createBlock(&newIf.getElseRegion());
2801       rewriter.setInsertionPointToEnd(newIf.elseBlock());
2802       rewriter.create<YieldOp>(loc, elseYield);
2803     }
2804     rewriter.replaceOp(op, results);
2805     return success();
2806   }
2807 };
2808 
2809 } // namespace
2810 
2811 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2812                                        MLIRContext *context) {
2813   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2814               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2815               RemoveStaticCondition, RemoveUnusedResults,
2816               ReplaceIfYieldWithConditionOrValue>(context);
2817 }
2818 
2819 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2820 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2821 Block *IfOp::elseBlock() {
2822   Region &r = getElseRegion();
2823   if (r.empty())
2824     return nullptr;
2825   return &r.back();
2826 }
2827 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2828 
2829 //===----------------------------------------------------------------------===//
2830 // ParallelOp
2831 //===----------------------------------------------------------------------===//
2832 
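// A hedged usage sketch of the builder below (hypothetical `lbs`, `ubs`,
// `steps` and `inits` values):
//
//   builder.create<scf::ParallelOp>(
//       loc, lbs, ubs, steps, inits,
//       [](OpBuilder &b, Location bodyLoc, ValueRange ivs,
//          ValueRange iterArgs) {
//         // Emit the loop body here. With a non-empty `inits` list the body
//         // must be terminated by an scf.reduce op.
//       });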
2833 void ParallelOp::build(
2834     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2835     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2836     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2837         bodyBuilderFn) {
2838   result.addOperands(lowerBounds);
2839   result.addOperands(upperBounds);
2840   result.addOperands(steps);
2841   result.addOperands(initVals);
2842   result.addAttribute(
2843       ParallelOp::getOperandSegmentSizeAttr(),
2844       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2845                                     static_cast<int32_t>(upperBounds.size()),
2846                                     static_cast<int32_t>(steps.size()),
2847                                     static_cast<int32_t>(initVals.size())}));
2848   result.addTypes(initVals.getTypes());
2849 
2850   OpBuilder::InsertionGuard guard(builder);
2851   unsigned numIVs = steps.size();
2852   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2853   SmallVector<Location, 8> argLocs(numIVs, result.location);
2854   Region *bodyRegion = result.addRegion();
2855   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2856 
2857   if (bodyBuilderFn) {
2858     builder.setInsertionPointToStart(bodyBlock);
2859     bodyBuilderFn(builder, result.location,
2860                   bodyBlock->getArguments().take_front(numIVs),
2861                   bodyBlock->getArguments().drop_front(numIVs));
2862   }
2863   // Add terminator only if there are no reductions.
2864   if (initVals.empty())
2865     ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2866 }
2867 
2868 void ParallelOp::build(
2869     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2870     ValueRange upperBounds, ValueRange steps,
2871     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2872   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2873   // we don't capture a reference to a temporary by constructing the lambda at
2874   // function level.
2875   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2876                                            Location nestedLoc, ValueRange ivs,
2877                                            ValueRange) {
2878     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2879   };
2880   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2881   if (bodyBuilderFn)
2882     wrapper = wrappedBuilderFn;
2883 
2884   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2885         wrapper);
2886 }
2887 
2888 LogicalResult ParallelOp::verify() {
2889   // Check that there is at least one value in lowerBound, upperBound and step.
2890   // It is sufficient to test only step because lowerBound, upperBound and step
2891   // are already known to have the same number of elements.
2892   Operation::operand_range stepValues = getStep();
2893   if (stepValues.empty())
2894     return emitOpError(
2895         "needs at least one tuple element for lowerBound, upperBound and step");
2896 
2897   // Check whether all constant step values are positive.
2898   for (Value stepValue : stepValues)
2899     if (auto cst = getConstantIntValue(stepValue))
2900       if (*cst <= 0)
2901         return emitOpError("constant step operand must be positive");
2902 
2903   // Check that the body defines the same number of block arguments as the
2904   // number of tuple elements in step.
2905   Block *body = getBody();
2906   if (body->getNumArguments() != stepValues.size())
2907     return emitOpError() << "expects the same number of induction variables: "
2908                          << body->getNumArguments()
2909                          << " as bound and step values: " << stepValues.size();
2910   for (auto arg : body->getArguments())
2911     if (!arg.getType().isIndex())
2912       return emitOpError(
2913           "expects arguments for the induction variable to be of index type");
2914 
2915   // Check that the terminator is an scf.reduce op.
2916   auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2917       *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2918   if (!reduceOp)
2919     return failure();
2920 
2921   // Check that the number of results is the same as the number of reductions.
2922   auto resultsSize = getResults().size();
2923   auto reductionsSize = reduceOp.getReductions().size();
2924   auto initValsSize = getInitVals().size();
2925   if (resultsSize != reductionsSize)
2926     return emitOpError() << "expects number of results: " << resultsSize
2927                          << " to be the same as number of reductions: "
2928                          << reductionsSize;
2929   if (resultsSize != initValsSize)
2930     return emitOpError() << "expects number of results: " << resultsSize
2931                          << " to be the same as number of initial values: "
2932                          << initValsSize;
2933 
2934   // Check that the types of the results and reductions are the same.
2935   for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2936     auto resultType = getOperation()->getResult(i).getType();
2937     auto reductionOperandType = reduceOp.getOperands()[i].getType();
2938     if (resultType != reductionOperandType)
2939       return reduceOp.emitOpError()
2940              << "expects type of " << i
2941              << "-th reduction operand: " << reductionOperandType
2942              << " to be the same as the " << i
2943              << "-th result type: " << resultType;
2944   }
2945   return success();
2946 }
2947 
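/// Parses a parallel loop. An illustrative (hypothetical) instance of the
/// accepted syntax, without reductions:
///
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     "test.payload"(%i, %j) : (index, index) -> ()
///   }
///
/// With reductions, an `init (...)` operand list and an arrow result type list
/// follow the step clause, and the body terminates with an scf.reduce op.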
2948 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2949   auto &builder = parser.getBuilder();
2950   // Parse an opening `(` followed by induction variables followed by `)`
2951   SmallVector<OpAsmParser::Argument, 4> ivs;
2952   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2953     return failure();
2954 
2955   // Parse loop bounds.
2956   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2957   if (parser.parseEqual() ||
2958       parser.parseOperandList(lower, ivs.size(),
2959                               OpAsmParser::Delimiter::Paren) ||
2960       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2961     return failure();
2962 
2963   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2964   if (parser.parseKeyword("to") ||
2965       parser.parseOperandList(upper, ivs.size(),
2966                               OpAsmParser::Delimiter::Paren) ||
2967       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2968     return failure();
2969 
2970   // Parse step values.
2971   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2972   if (parser.parseKeyword("step") ||
2973       parser.parseOperandList(steps, ivs.size(),
2974                               OpAsmParser::Delimiter::Paren) ||
2975       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2976     return failure();
2977 
2978   // Parse init values.
2979   SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2980   if (succeeded(parser.parseOptionalKeyword("init"))) {
2981     if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2982       return failure();
2983   }
2984 
2985   // Parse optional results in case there is a reduce.
2986   if (parser.parseOptionalArrowTypeList(result.types))
2987     return failure();
2988 
2989   // Now parse the body.
2990   Region *body = result.addRegion();
2991   for (auto &iv : ivs)
2992     iv.type = builder.getIndexType();
2993   if (parser.parseRegion(*body, ivs))
2994     return failure();
2995 
2996   // Set `operandSegmentSizes` attribute.
2997   result.addAttribute(
2998       ParallelOp::getOperandSegmentSizeAttr(),
2999       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3000                                     static_cast<int32_t>(upper.size()),
3001                                     static_cast<int32_t>(steps.size()),
3002                                     static_cast<int32_t>(initVals.size())}));
3003 
3004   // Parse attributes.
3005   if (parser.parseOptionalAttrDict(result.attributes) ||
3006       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3007                              result.operands))
3008     return failure();
3009 
3010   // Add a terminator if none was parsed.
3011   ParallelOp::ensureTerminator(*body, builder, result.location);
3012   return success();
3013 }
3014 
3015 void ParallelOp::print(OpAsmPrinter &p) {
3016   p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3017     << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3018   if (!getInitVals().empty())
3019     p << " init (" << getInitVals() << ")";
3020   p.printOptionalArrowTypeList(getResultTypes());
3021   p << ' ';
3022   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3023   p.printOptionalAttrDict(
3024       (*this)->getAttrs(),
3025       /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3026 }
3027 
3028 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3029 
3030 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3031   return SmallVector<Value>{getBody()->getArguments()};
3032 }
3033 
3034 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3035   return getLowerBound();
3036 }
3037 
3038 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3039   return getUpperBound();
3040 }
3041 
3042 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3043   return getStep();
3044 }
3045 
3046 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3047   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3048   if (!ivArg)
3049     return ParallelOp();
3050   assert(ivArg.getOwner() && "unlinked block argument");
3051   auto *containingOp = ivArg.getOwner()->getParentOp();
3052   return dyn_cast<ParallelOp>(containingOp);
3053 }
3054 
3055 namespace {
3056 // Collapse loop dimensions that perform zero or a single iteration.
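// An illustrative sketch, with %c0/%c1/%c4 index constants 0, 1 and 4:
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c4) step (%c1, %c1) {
//     "test.op"(%i, %j) : (index, index) -> ()
//   }
//
// becomes
//
//   scf.parallel (%j) = (%c0) to (%c4) step (%c1) {
//     "test.op"(%c0, %j) : (index, index) -> ()
//   }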
3057 struct ParallelOpSingleOrZeroIterationDimsFolder
3058     : public OpRewritePattern<ParallelOp> {
3059   using OpRewritePattern<ParallelOp>::OpRewritePattern;
3060 
3061   LogicalResult matchAndRewrite(ParallelOp op,
3062                                 PatternRewriter &rewriter) const override {
3063     Location loc = op.getLoc();
3064 
3065     // Compute new loop bounds that omit all single-iteration loop dimensions.
3066     SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3067     IRMapping mapping;
3068     for (auto [lb, ub, step, iv] :
3069          llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3070                    op.getInductionVars())) {
3071       auto numIterations = constantTripCount(lb, ub, step);
3072       if (numIterations.has_value()) {
3073         // Remove the loop if it performs zero iterations.
3074         if (*numIterations == 0) {
3075           rewriter.replaceOp(op, op.getInitVals());
3076           return success();
3077         }
3078         // Replace the loop induction variable by the lower bound if the loop
3079         // performs a single iteration. Otherwise, copy the loop bounds.
3080         if (*numIterations == 1) {
3081           mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3082           continue;
3083         }
3084       }
3085       newLowerBounds.push_back(lb);
3086       newUpperBounds.push_back(ub);
3087       newSteps.push_back(step);
3088     }
3089     // Exit if none of the loop dimensions perform a single iteration.
3090     if (newLowerBounds.size() == op.getLowerBound().size())
3091       return failure();
3092 
3093     if (newLowerBounds.empty()) {
3094       // All of the loop dimensions perform a single iteration. Inline the
3095       // loop body and the nested ReduceOps.
3096       SmallVector<Value> results;
3097       results.reserve(op.getInitVals().size());
3098       for (auto &bodyOp : op.getBody()->without_terminator())
3099         rewriter.clone(bodyOp, mapping);
3100       auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3101       for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3102         Block &reduceBlock = reduceOp.getReductions()[i].front();
3103         auto initValIndex = results.size();
3104         mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3105         mapping.map(reduceBlock.getArgument(1),
3106                     mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3107         for (auto &reduceBodyOp : reduceBlock.without_terminator())
3108           rewriter.clone(reduceBodyOp, mapping);
3109 
3110         auto result = mapping.lookupOrDefault(
3111             cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3112         results.push_back(result);
3113       }
3114 
3115       rewriter.replaceOp(op, results);
3116       return success();
3117     }
3118     // Replace the parallel loop with a lower-dimensional parallel loop.
3119     auto newOp =
3120         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3121                                     newSteps, op.getInitVals(), nullptr);
3122     // Erase the empty block that was inserted by the builder.
3123     rewriter.eraseBlock(newOp.getBody());
3124     // Clone the loop body and remap the block arguments of the collapsed loops
3125     // (inlining does not support a cancellable block argument mapping).
3126     rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3127                                newOp.getRegion().begin(), mapping);
3128     rewriter.replaceOp(op, newOp.getResults());
3129     return success();
3130   }
3131 };
3132 
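// Merge a perfectly nested inner scf.parallel into its parent. An illustrative
// sketch:
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
//       "test.op"(%i, %j) : (index, index) -> ()
//     }
//   }
//
// becomes
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     "test.op"(%i, %j) : (index, index) -> ()
//   }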
3133 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3134   using OpRewritePattern<ParallelOp>::OpRewritePattern;
3135 
3136   LogicalResult matchAndRewrite(ParallelOp op,
3137                                 PatternRewriter &rewriter) const override {
3138     Block &outerBody = *op.getBody();
3139     if (!llvm::hasSingleElement(outerBody.without_terminator()))
3140       return failure();
3141 
3142     auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3143     if (!innerOp)
3144       return failure();
3145 
3146     for (auto val : outerBody.getArguments())
3147       if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3148           llvm::is_contained(innerOp.getUpperBound(), val) ||
3149           llvm::is_contained(innerOp.getStep(), val))
3150         return failure();
3151 
3152     // Reductions are not supported yet.
3153     if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3154       return failure();
3155 
3156     auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3157                            ValueRange iterVals, ValueRange) {
3158       Block &innerBody = *innerOp.getBody();
3159       assert(iterVals.size() ==
3160              (outerBody.getNumArguments() + innerBody.getNumArguments()));
3161       IRMapping mapping;
3162       mapping.map(outerBody.getArguments(),
3163                   iterVals.take_front(outerBody.getNumArguments()));
3164       mapping.map(innerBody.getArguments(),
3165                   iterVals.take_back(innerBody.getNumArguments()));
3166       for (Operation &op : innerBody.without_terminator())
3167         builder.clone(op, mapping);
3168     };
3169 
3170     auto concatValues = [](const auto &first, const auto &second) {
3171       SmallVector<Value> ret;
3172       ret.reserve(first.size() + second.size());
3173       ret.assign(first.begin(), first.end());
3174       ret.append(second.begin(), second.end());
3175       return ret;
3176     };
3177 
3178     auto newLowerBounds =
3179         concatValues(op.getLowerBound(), innerOp.getLowerBound());
3180     auto newUpperBounds =
3181         concatValues(op.getUpperBound(), innerOp.getUpperBound());
3182     auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3183 
3184     rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3185                                             newSteps, std::nullopt,
3186                                             bodyBuilder);
3187     return success();
3188   }
3189 };
3190 
3191 } // namespace
3192 
3193 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3194                                              MLIRContext *context) {
3195   results
3196       .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3197           context);
3198 }
3199 
3200 /// Populate `regions` with the successor regions, i.e. the regions that may
3201 /// be selected during the flow of control when branching from `point`.
3202 /// `point` is either the parent operation itself or the single body region
3203 /// of the parallel loop; in both cases control may proceed into the body or
3204 /// transfer back to the parallel op itself.
3205 void ParallelOp::getSuccessorRegions(
3206     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3207   // Both the operation itself and the region may be branching into the body or
3208   // back into the operation itself. It is possible for the loop not to enter
3209   // the body.
3210   regions.push_back(RegionSuccessor(&getRegion()));
3211   regions.push_back(RegionSuccessor());
3212 }
3213 
3214 //===----------------------------------------------------------------------===//
3215 // ReduceOp
3216 //===----------------------------------------------------------------------===//
3217 
3218 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3219 
3220 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3221                      ValueRange operands) {
3222   result.addOperands(operands);
3223   for (Value v : operands) {
3224     OpBuilder::InsertionGuard guard(builder);
3225     Region *bodyRegion = result.addRegion();
3226     builder.createBlock(bodyRegion, {},
3227                         ArrayRef<Type>{v.getType(), v.getType()},
3228                         {result.location, result.location});
3229   }
3230 }
3231 
3232 LogicalResult ReduceOp::verifyRegions() {
3233   // The region of a ReduceOp has two arguments of the same type as its
3234   // corresponding operand.
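  // For instance, a well-formed reduction region looks like the following
  // (an illustrative sketch; the exact textual syntax is abbreviated):
  //   scf.reduce(%operand : f32) {
  //   ^bb0(%lhs: f32, %rhs: f32):
  //     %sum = arith.addf %lhs, %rhs : f32
  //     scf.reduce.return %sum : f32
  //   }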
3235   for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3236     auto type = getOperands()[i].getType();
3237     Block &block = getReductions()[i].front();
3238     if (block.empty())
3239       return emitOpError() << i << "-th reduction has an empty body";
3240     if (block.getNumArguments() != 2 ||
3241         llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3242           return arg.getType() != type;
3243         }))
3244       return emitOpError() << "expected two block arguments with type " << type
3245                            << " in the " << i << "-th reduction region";
3246 
3247     // Check that the block is terminated by a ReduceReturnOp.
3248     if (!isa<ReduceReturnOp>(block.getTerminator()))
3249       return emitOpError("reduction bodies must be terminated with an "
3250                          "'scf.reduce.return' op");
3251   }
3252 
3253   return success();
3254 }
3255 
3256 MutableOperandRange
3257 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3258   // No operands are forwarded to the next iteration.
3259   return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3260 }
3261 
3262 //===----------------------------------------------------------------------===//
3263 // ReduceReturnOp
3264 //===----------------------------------------------------------------------===//
3265 
3266 LogicalResult ReduceReturnOp::verify() {
3267   // The type of the return value should be the same type as the types of the
3268   // block arguments of the reduction body.
3269   Block *reductionBody = getOperation()->getBlock();
3270   // Should already be verified by an op trait.
3271   assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3272   Type expectedResultType = reductionBody->getArgument(0).getType();
3273   if (expectedResultType != getResult().getType())
3274     return emitOpError() << "must have type " << expectedResultType
3275                          << " (the type of the reduction inputs)";
3276   return success();
3277 }
3278 
3279 //===----------------------------------------------------------------------===//
3280 // WhileOp
3281 //===----------------------------------------------------------------------===//
3282 
3283 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3284                     ::mlir::OperationState &odsState, TypeRange resultTypes,
3285                     ValueRange inits, BodyBuilderFn beforeBuilder,
3286                     BodyBuilderFn afterBuilder) {
3287   odsState.addOperands(inits);
3288   odsState.addTypes(resultTypes);
3289 
3290   OpBuilder::InsertionGuard guard(odsBuilder);
3291 
3292   // Build before region.
3293   SmallVector<Location, 4> beforeArgLocs;
3294   beforeArgLocs.reserve(inits.size());
3295   for (Value operand : inits) {
3296     beforeArgLocs.push_back(operand.getLoc());
3297   }
3298 
3299   Region *beforeRegion = odsState.addRegion();
3300   Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3301                                               inits.getTypes(), beforeArgLocs);
3302   if (beforeBuilder)
3303     beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3304 
3305   // Build after region.
3306   SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3307 
3308   Region *afterRegion = odsState.addRegion();
3309   Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3310                                              resultTypes, afterArgLocs);
3311 
3312   if (afterBuilder)
3313     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3314 }
3315 
3316 ConditionOp WhileOp::getConditionOp() {
3317   return cast<ConditionOp>(getBeforeBody()->getTerminator());
3318 }
3319 
3320 YieldOp WhileOp::getYieldOp() {
3321   return cast<YieldOp>(getAfterBody()->getTerminator());
3322 }
3323 
3324 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3325   return getYieldOp().getResultsMutable();
3326 }
3327 
3328 Block::BlockArgListType WhileOp::getBeforeArguments() {
3329   return getBeforeBody()->getArguments();
3330 }
3331 
3332 Block::BlockArgListType WhileOp::getAfterArguments() {
3333   return getAfterBody()->getArguments();
3334 }
3335 
3336 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3337   return getBeforeArguments();
3338 }
3339 
3340 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3341   assert(point == getBefore() &&
3342          "WhileOp is expected to branch only to the first region");
3343   return getInits();
3344 }
3345 
3346 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3347                                   SmallVectorImpl<RegionSuccessor> &regions) {
3348   // The parent op always branches to the condition region.
3349   if (point.isParent()) {
3350     regions.emplace_back(&getBefore(), getBefore().getArguments());
3351     return;
3352   }
3353 
3354   assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3355          "there are only two regions in a WhileOp");
3356   // The body region always branches back to the condition region.
3357   if (point == getAfter()) {
3358     regions.emplace_back(&getBefore(), getBefore().getArguments());
3359     return;
3360   }
3361 
3362   regions.emplace_back(getResults());
3363   regions.emplace_back(&getAfter(), getAfter().getArguments());
3364 }
3365 
3366 SmallVector<Region *> WhileOp::getLoopRegions() {
3367   return {&getBefore(), &getAfter()};
3368 }
3369 
3370 /// Parses a `while` op.
3371 ///
3372 /// op ::= `scf.while` initializer `:` function-type region `do` region
3373 ///         `attributes` attribute-dict
3374 /// initializer ::= /* empty */ | `(` assignment-list `)`
3375 /// assignment-list ::= assignment | assignment `,` assignment-list
3376 /// assignment ::= ssa-value `=` ssa-value
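///
/// An illustrative (hypothetical) instance of this syntax:
///
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = "test.condition"(%arg) : (i32) -> i1
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = "test.next"(%arg2) : (i32) -> i32
///     scf.yield %next : i32
///   }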
3377 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3378   SmallVector<OpAsmParser::Argument, 4> regionArgs;
3379   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3380   Region *before = result.addRegion();
3381   Region *after = result.addRegion();
3382 
3383   OptionalParseResult listResult =
3384       parser.parseOptionalAssignmentList(regionArgs, operands);
3385   if (listResult.has_value() && failed(listResult.value()))
3386     return failure();
3387 
3388   FunctionType functionType;
3389   SMLoc typeLoc = parser.getCurrentLocation();
3390   if (failed(parser.parseColonType(functionType)))
3391     return failure();
3392 
3393   result.addTypes(functionType.getResults());
3394 
3395   if (functionType.getNumInputs() != operands.size()) {
3396     return parser.emitError(typeLoc)
3397            << "expected as many input types as operands "
3398            << "(expected " << operands.size() << " got "
3399            << functionType.getNumInputs() << ")";
3400   }
3401 
3402   // Resolve input operands.
3403   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3404                                     parser.getCurrentLocation(),
3405                                     result.operands)))
3406     return failure();
3407 
3408   // Propagate the types into the region arguments.
3409   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3410     regionArgs[i].type = functionType.getInput(i);
3411 
3412   return failure(parser.parseRegion(*before, regionArgs) ||
3413                  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3414                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3415 }
3416 
3417 /// Prints a `while` op.
3418 void scf::WhileOp::print(OpAsmPrinter &p) {
3419   printInitializationList(p, getBeforeArguments(), getInits(), " ");
3420   p << " : ";
3421   p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3422   p << ' ';
3423   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3424   p << " do ";
3425   p.printRegion(getAfter());
3426   p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3427 }
3428 
3429 /// Verifies that two ranges of types match, i.e. have the same number of
3430 /// entries and that types are pairwise equal. Reports errors on the given
3431 /// operation in case of mismatch.
3432 template <typename OpTy>
3433 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3434                                            TypeRange right, StringRef message) {
3435   if (left.size() != right.size())
3436     return op.emitOpError("expects the same number of ") << message;
3437 
3438   for (unsigned i = 0, e = left.size(); i < e; ++i) {
3439     if (left[i] != right[i]) {
3440       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3441                                 << message;
3442       diag.attachNote() << "for argument " << i << ", found " << left[i]
3443                         << " and " << right[i];
3444       return diag;
3445     }
3446   }
3447 
3448   return success();
3449 }
3450 
3451 LogicalResult scf::WhileOp::verify() {
3452   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3453       *this, getBefore(),
3454       "expects the 'before' region to terminate with 'scf.condition'");
3455   if (!beforeTerminator)
3456     return failure();
3457 
3458   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3459       *this, getAfter(),
3460       "expects the 'after' region to terminate with 'scf.yield'");
3461   return success(afterTerminator != nullptr);
3462 }
3463 
3464 namespace {
3465 /// Replace uses of the condition within the do block with true, since otherwise
3466 /// the block would not be evaluated.
3467 ///
3468 /// scf.while (..) : (i1, ...) -> ... {
3469 ///  %condition = call @evaluate_condition() : () -> i1
3470 ///  scf.condition(%condition) %condition : i1, ...
3471 /// } do {
3472 /// ^bb0(%arg0: i1, ...):
3473 ///    use(%arg0)
3474 ///    ...
3475 ///
3476 /// becomes
3477 /// scf.while (..) : (i1, ...) -> ... {
3478 ///  %condition = call @evaluate_condition() : () -> i1
3479 ///  scf.condition(%condition) %condition : i1, ...
3480 /// } do {
3481 /// ^bb0(%arg0: i1, ...):
3482 ///    use(%true)
3483 ///    ...
3484 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3485   using OpRewritePattern<WhileOp>::OpRewritePattern;
3486 
3487   LogicalResult matchAndRewrite(WhileOp op,
3488                                 PatternRewriter &rewriter) const override {
3489     auto term = op.getConditionOp();
3490 
3491     // Caches a constant `true` value so that at most one such constant is
3492     // created while replacing uses of the yielded condition below.
3493     Value constantTrue = nullptr;
3494 
3495     bool replaced = false;
3496     for (auto yieldedAndBlockArgs :
3497          llvm::zip(term.getArgs(), op.getAfterArguments())) {
3498       if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3499         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3500           if (!constantTrue)
3501             constantTrue = rewriter.create<arith::ConstantOp>(
3502                 op.getLoc(), term.getCondition().getType(),
3503                 rewriter.getBoolAttr(true));
3504 
3505           rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3506                                       constantTrue);
3507           replaced = true;
3508         }
3509       }
3510     }
3511     return success(replaced);
3512   }
3513 };
3514 
3515 /// Remove loop invariant arguments from `before` block of scf.while.
3516 /// A before block argument is considered loop invariant if :-
3517 ///   1. i-th yield operand is equal to the i-th while operand.
3518 ///   2. i-th yield operand is k-th after block argument which is (k+1)-th
3519 ///      condition operand AND this (k+1)-th condition operand is equal to i-th
3520 ///      iter argument/while operand.
3521 /// For the arguments which are removed, their uses inside scf.while
3522 /// are replaced with their corresponding initial value.
3523 ///
3524 /// Eg:
3525 ///    INPUT :-
3526 ///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3527 ///                                     ..., %argN_before = %N)
3528 ///           {
3529 ///                ...
3530 ///                scf.condition(%cond) %arg1_before, %arg0_before,
3531 ///                                     %arg2_before, %arg0_before, ...
3532 ///           } do {
3533 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3534 ///                  ..., %argK_after):
3535 ///                ...
3536 ///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3537 ///           }
3538 ///
3539 ///    OUTPUT :-
3540 ///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3541 ///                                     %N)
3542 ///           {
3543 ///                ...
3544 ///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3545 ///           } do {
3546 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3547 ///                  ..., %argK_after):
3548 ///                ...
3549 ///                scf.yield %arg1_after, ..., %argN
3550 ///           }
3551 ///
3552 ///    EXPLANATION:
3553 ///      We iterate over each yield operand.
3554 ///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3555 ///           %arg0_before, which in turn is the 0-th iter argument. So we
3556 ///           remove 0-th before block argument and yield operand, and replace
3557 ///           all uses of the 0-th before block argument with its initial value
3558 ///           %a.
3559 ///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3560 ///           value. So we remove this operand and the corresponding before
3561 ///           block argument and replace all uses of 1-th before block argument
3562 ///           with %b.
3563 struct RemoveLoopInvariantArgsFromBeforeBlock
3564     : public OpRewritePattern<WhileOp> {
3565   using OpRewritePattern<WhileOp>::OpRewritePattern;
3566 
3567   LogicalResult matchAndRewrite(WhileOp op,
3568                                 PatternRewriter &rewriter) const override {
3569     Block &afterBlock = *op.getAfterBody();
3570     Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3571     ConditionOp condOp = op.getConditionOp();
3572     OperandRange condOpArgs = condOp.getArgs();
3573     Operation *yieldOp = afterBlock.getTerminator();
3574     ValueRange yieldOpArgs = yieldOp->getOperands();
3575 
3576     bool canSimplify = false;
3577     for (const auto &it :
3578          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3579       auto index = static_cast<unsigned>(it.index());
3580       auto [initVal, yieldOpArg] = it.value();
3581       // If i-th yield operand is equal to the i-th operand of the scf.while,
3582       // the i-th before block argument is a loop invariant.
3583       if (yieldOpArg == initVal) {
3584         canSimplify = true;
3585         break;
3586       }
3587       // If the i-th yield operand is k-th after block argument, then we check
3588       // if the (k+1)-th condition op operand is equal to either the i-th before
3589       // block argument or the initial value of i-th before block argument. If
3590       // the comparison results `true`, i-th before block argument is a loop
3591       // invariant.
3592       auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3593       if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3594         Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3595         if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3596           canSimplify = true;
3597           break;
3598         }
3599       }
3600     }
3601 
3602     if (!canSimplify)
3603       return failure();
3604 
3605     SmallVector<Value> newInitArgs, newYieldOpArgs;
3606     DenseMap<unsigned, Value> beforeBlockInitValMap;
3607     SmallVector<Location> newBeforeBlockArgLocs;
3608     for (const auto &it :
3609          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3610       auto index = static_cast<unsigned>(it.index());
3611       auto [initVal, yieldOpArg] = it.value();
3612 
3613       // If i-th yield operand is equal to the i-th operand of the scf.while,
3614       // the i-th before block argument is a loop invariant.
3615       if (yieldOpArg == initVal) {
3616         beforeBlockInitValMap.insert({index, initVal});
3617         continue;
3618       } else {
3619         // If the i-th yield operand is k-th after block argument, then we check
3620         // if the (k+1)-th condition op operand is equal to either the i-th
3621         // before block argument or the initial value of i-th before block
3622         // argument. If the comparison results `true`, i-th before block
3623         // argument is a loop invariant.
3624         auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3625         if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3626           Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3627           if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3628             beforeBlockInitValMap.insert({index, initVal});
3629             continue;
3630           }
3631         }
3632       }
3633       newInitArgs.emplace_back(initVal);
3634       newYieldOpArgs.emplace_back(yieldOpArg);
3635       newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3636     }
3637 
3638     {
3639       OpBuilder::InsertionGuard g(rewriter);
3640       rewriter.setInsertionPoint(yieldOp);
3641       rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3642     }
3643 
3644     auto newWhile =
3645         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3646 
3647     Block &newBeforeBlock = *rewriter.createBlock(
3648         &newWhile.getBefore(), /*insertPt*/ {},
3649         ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3650 
3651     Block &beforeBlock = *op.getBeforeBody();
3652     SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3653     // For each i-th before block argument we find its replacement value as:
3654     //   1. If the i-th before block argument is a loop invariant, we fetch its
3655     //      initial value from `beforeBlockInitValMap` by querying for key `i`.
3656     //   2. Else we fetch j-th new before block argument as the replacement
3657     //      value of i-th before block argument.
3658     for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3659       // If the index 'i' argument was a loop invariant we fetch its initial
3660       // value from `beforeBlockInitValMap`.
3661       if (beforeBlockInitValMap.count(i) != 0)
3662         newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3663       else
3664         newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3665     }
3666 
3667     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3668     rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3669                                 newWhile.getAfter().begin());
3670 
3671     rewriter.replaceOp(op, newWhile.getResults());
3672     return success();
3673   }
3674 };
3675 
3676 /// Remove loop invariant value from result (condition op) of scf.while.
3677 /// A value is considered loop invariant if the final value yielded by
3678 /// scf.condition is defined outside of the `before` block. We remove the
3679 /// corresponding argument in `after` block and replace the use with the value.
3680 /// We also replace the use of the corresponding result of scf.while with the
3681 /// value.
3682 ///
3683 /// Eg:
3684 ///    INPUT :-
3685 ///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3686 ///                                             %argN_before = %N) {
3687 ///                ...
3688 ///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3689 ///           } do {
3690 ///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3691 ///                ...
3692 ///                some_func(%arg1_after)
3693 ///                ...
3694 ///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
3695 ///           }
3696 ///
3697 ///    OUTPUT :-
3698 ///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3699 ///                ...
3700 ///                scf.condition(%cond) %arg0, %arg1, ..., %argM
3701 ///           } do {
3702 ///             ^bb0(%arg0, %arg3, ..., %argM):
3703 ///                ...
3704 ///                some_func(%a)
3705 ///                ...
3706 ///                scf.yield %arg0, %b, ..., %argN
3707 ///           }
3708 ///
3709 ///     EXPLANATION:
3710 ///       1. The 1-th and 2-th operand of scf.condition are defined outside the
3711 ///          before block of scf.while, so they get removed.
3712 ///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3713 ///          replaced by %b.
3714 ///       3. The corresponding after block argument %arg1_after's uses are
3715 ///          replaced by %a and %arg2_after's uses are replaced by %b.
3716 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3717   using OpRewritePattern<WhileOp>::OpRewritePattern;
3718 
3719   LogicalResult matchAndRewrite(WhileOp op,
3720                                 PatternRewriter &rewriter) const override {
3721     Block &beforeBlock = *op.getBeforeBody();
3722     ConditionOp condOp = op.getConditionOp();
3723     OperandRange condOpArgs = condOp.getArgs();
3724 
3725     bool canSimplify = false;
3726     for (Value condOpArg : condOpArgs) {
3727       // Values not defined within the `before` block are considered loop
3728       // invariant; finding one means the op can be simplified by the
3729       // rewrite below.
3730       if (condOpArg.getParentBlock() != &beforeBlock) {
3731         canSimplify = true;
3732         break;
3733       }
3734     }
3735 
3736     if (!canSimplify)
3737       return failure();
3738 
3739     Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3740 
3741     SmallVector<Value> newCondOpArgs;
3742     SmallVector<Type> newAfterBlockType;
3743     DenseMap<unsigned, Value> condOpInitValMap;
3744     SmallVector<Location> newAfterBlockArgLocs;
3745     for (const auto &it : llvm::enumerate(condOpArgs)) {
3746       auto index = static_cast<unsigned>(it.index());
3747       Value condOpArg = it.value();
3748       // Values not defined within the `before` block are considered loop
3749       // invariant; we map the corresponding `index` to the invariant
3750       // value.
3751       if (condOpArg.getParentBlock() != &beforeBlock) {
3752         condOpInitValMap.insert({index, condOpArg});
3753       } else {
3754         newCondOpArgs.emplace_back(condOpArg);
3755         newAfterBlockType.emplace_back(condOpArg.getType());
3756         newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3757       }
3758     }
3759 
3760     {
3761       OpBuilder::InsertionGuard g(rewriter);
3762       rewriter.setInsertionPoint(condOp);
3763       rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3764                                                newCondOpArgs);
3765     }
3766 
3767     auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3768                                              op.getOperands());
3769 
3770     Block &newAfterBlock =
3771         *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3772                               newAfterBlockType, newAfterBlockArgLocs);
3773 
3774     Block &afterBlock = *op.getAfterBody();
3775     // Since a new scf.condition op was created, we need to fetch the new
3776     // `after` block arguments which will be used while replacing operations of
3777     // previous scf.while's `after` blocks. We'd also be fetching new result
3778     // values too.
3779     SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3780     SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3781     for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3782       Value afterBlockArg, result;
3783       // If the index 'i' argument was loop invariant we fetch its value from
3784       // the `condOpInitValMap` map.
3785       if (condOpInitValMap.count(i) != 0) {
3786         afterBlockArg = condOpInitValMap[i];
3787         result = afterBlockArg;
3788       } else {
3789         afterBlockArg = newAfterBlock.getArgument(j);
3790         result = newWhile.getResult(j);
3791         j++;
3792       }
3793       newAfterBlockArgs[i] = afterBlockArg;
3794       newWhileResults[i] = result;
3795     }
3796 
3797     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3798     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3799                                 newWhile.getBefore().begin());
3800 
3801     rewriter.replaceOp(op, newWhileResults);
3802     return success();
3803   }
3804 };
3805 
3806 /// Remove WhileOp results whose values and matching 'after' arguments are unused.
3807 ///
3808 ///  %0:2 = scf.while () : () -> (i32, i64) {
3809 ///    %condition = "test.condition"() : () -> i1
3810 ///    %v1 = "test.get_some_value"() : () -> i32
3811 ///    %v2 = "test.get_some_value"() : () -> i64
3812 ///    scf.condition(%condition) %v1, %v2 : i32, i64
3813 ///  } do {
3814 ///  ^bb0(%arg0: i32, %arg1: i64):
3815 ///    "test.use"(%arg0) : (i32) -> ()
3816 ///    scf.yield
3817 ///  }
3818 ///  return %0#0 : i32
3819 ///
3820 /// becomes
3821 ///  %0 = scf.while () : () -> (i32) {
3822 ///    %condition = "test.condition"() : () -> i1
3823 ///    %v1 = "test.get_some_value"() : () -> i32
3824 ///    %v2 = "test.get_some_value"() : () -> i64
3825 ///    scf.condition(%condition) %v1 : i32
3826 ///  } do {
3827 ///  ^bb0(%arg0: i32):
3828 ///    "test.use"(%arg0) : (i32) -> ()
3829 ///    scf.yield
3830 ///  }
3831 ///  return %0 : i32
3832 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3833   using OpRewritePattern<WhileOp>::OpRewritePattern;
3834 
3835   LogicalResult matchAndRewrite(WhileOp op,
3836                                 PatternRewriter &rewriter) const override {
3837     auto term = op.getConditionOp();
3838     auto afterArgs = op.getAfterArguments();
3839     auto termArgs = term.getArgs();
3840 
3841     // Collect results mapping, new terminator args and new result types.
3842     SmallVector<unsigned> newResultsIndices;
3843     SmallVector<Type> newResultTypes;
3844     SmallVector<Value> newTermArgs;
3845     SmallVector<Location> newArgLocs;
3846     bool needUpdate = false;
3847     for (const auto &it :
3848          llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3849       auto i = static_cast<unsigned>(it.index());
3850       Value result = std::get<0>(it.value());
3851       Value afterArg = std::get<1>(it.value());
3852       Value termArg = std::get<2>(it.value());
3853       if (result.use_empty() && afterArg.use_empty()) {
3854         needUpdate = true;
3855       } else {
3856         newResultsIndices.emplace_back(i);
3857         newTermArgs.emplace_back(termArg);
3858         newResultTypes.emplace_back(result.getType());
3859         newArgLocs.emplace_back(result.getLoc());
3860       }
3861     }
3862 
3863     if (!needUpdate)
3864       return failure();
3865 
3866     {
3867       OpBuilder::InsertionGuard g(rewriter);
3868       rewriter.setInsertionPoint(term);
3869       rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3870                                                newTermArgs);
3871     }
3872 
3873     auto newWhile =
3874         rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3875 
3876     Block &newAfterBlock = *rewriter.createBlock(
3877         &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3878 
3879     // Build new results list and new after block args (unused entries will be
3880     // null).
3881     SmallVector<Value> newResults(op.getNumResults());
3882     SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3883     for (const auto &it : llvm::enumerate(newResultsIndices)) {
3884       newResults[it.value()] = newWhile.getResult(it.index());
3885       newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3886     }
3887 
3888     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3889                                 newWhile.getBefore().begin());
3890 
3891     Block &afterBlock = *op.getAfterBody();
3892     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3893 
3894     rewriter.replaceOp(op, newResults);
3895     return success();
3896   }
3897 };
3898 
3899 /// Fold comparisons in the 'after' (do) block that recompute the condition
3900 /// to a constant, since the block only runs when the condition is true.
3901 ///
3902 /// scf.while (..) : (i32, ...) -> ... {
3903 ///  %z = ... : i32
3904 ///  %condition = cmpi pred %z, %a
3905 ///  scf.condition(%condition) %z : i32, ...
3906 /// } do {
3907 /// ^bb0(%arg0: i32, ...):
3908 ///    %condition2 = cmpi pred %arg0, %a
3909 ///    use(%condition2)
3910 ///    ...
3911 ///
3912 /// becomes
3913 /// scf.while (..) : (i32, ...) -> ... {
3914 ///  %z = ... : i32
3915 ///  %condition = cmpi pred %z, %a
3916 ///  scf.condition(%condition) %z : i32, ...
3917 /// } do {
3918 /// ^bb0(%arg0: i32, ...):
3919 ///    use(%true)
3920 ///    ...
3921 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3922   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3923 
3924   LogicalResult matchAndRewrite(scf::WhileOp op,
3925                                 PatternRewriter &rewriter) const override {
3926     using namespace scf;
3927     auto cond = op.getConditionOp();
3928     auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3929     if (!cmp)
3930       return failure();
3931     bool changed = false;
3932     for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3933       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3934         if (std::get<0>(tup) != cmp.getOperand(opIdx))
3935           continue;
3936         for (OpOperand &u :
3937              llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3938           auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3939           if (!cmp2)
3940             continue;
3941           // For a binary operator, 1 - opIdx selects the other operand.
3942           if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3943             continue;
3944           bool samePredicate;
3945           if (cmp2.getPredicate() == cmp.getPredicate())
3946             samePredicate = true;
3947           else if (cmp2.getPredicate() ==
3948                    arith::invertPredicate(cmp.getPredicate()))
3949             samePredicate = false;
3950           else
3951             continue;
3952 
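          // Replace the redundant comparison in the 'after' block with a
          // constant i1: true when the predicates match, false when they are
          // inverses of each other.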
3953           rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3954                                                             1);
3955           changed = true;
3956         }
3957       }
3958     }
3959     return success(changed);
3960   }
3961 };
3962 
3963 /// Remove unused init/yield args.
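///
/// A rough sketch of the rewrite (hypothetical IR, in the spirit of the
/// examples above):
///
///  %0 = scf.while (%arg0 = %init0, %arg1 = %init1) : (i32, i64) -> (i32) {
///    %condition = "test.condition"(%arg0) : (i32) -> i1
///    scf.condition(%condition) %arg0 : i32
///  } do {
///  ^bb0(%arg2: i32):
///    %next = "test.next"(%arg2) : (i32) -> i32
///    %other = "test.other"() : () -> i64
///    scf.yield %next, %other : i32, i64
///  }
///
/// Since %arg1 is unused in the 'before' block, %init1 and the corresponding
/// yield operand are dropped:
///
///  %0 = scf.while (%arg0 = %init0) : (i32) -> (i32) {
///    ...
///    scf.yield %next : i32
///  }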
3964 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3965   using OpRewritePattern<WhileOp>::OpRewritePattern;
3966 
3967   LogicalResult matchAndRewrite(WhileOp op,
3968                                 PatternRewriter &rewriter) const override {
3969 
3970     if (!llvm::any_of(op.getBeforeArguments(),
3971                       [](Value arg) { return arg.use_empty(); }))
3972       return rewriter.notifyMatchFailure(op, "No args to remove");
3973 
3974     YieldOp yield = op.getYieldOp();
3975 
3976     // Collect the new yield and init values and the 'before' args to erase.
3977     SmallVector<Value> newYields;
3978     SmallVector<Value> newInits;
3979     llvm::BitVector argsToErase;
3980 
3981     size_t argsCount = op.getBeforeArguments().size();
3982     newYields.reserve(argsCount);
3983     newInits.reserve(argsCount);
3984     argsToErase.reserve(argsCount);
3985     for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3986              op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3987       if (beforeArg.use_empty()) {
3988         argsToErase.push_back(true);
3989       } else {
3990         argsToErase.push_back(false);
3991         newYields.emplace_back(yieldValue);
3992         newInits.emplace_back(initValue);
3993       }
3994     }
3995 
3996     Block &beforeBlock = *op.getBeforeBody();
3997     Block &afterBlock = *op.getAfterBody();
3998 
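    // Erase the unused arguments from the old 'before' block so its remaining
    // arguments line up 1:1 with the new op's 'before' block arguments when
    // the blocks are merged below.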
3999     beforeBlock.eraseArguments(argsToErase);
4000 
4001     Location loc = op.getLoc();
4002     auto newWhileOp =
4003         rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4004                                  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4005     Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4006     Block &newAfterBlock = *newWhileOp.getAfterBody();
4007 
4008     OpBuilder::InsertionGuard g(rewriter);
4009     rewriter.setInsertionPoint(yield);
4010     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4011 
4012     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4013                          newBeforeBlock.getArguments());
4014     rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4015                          newAfterBlock.getArguments());
4016 
4017     rewriter.replaceOp(op, newWhileOp.getResults());
4018     return success();
4019   }
4020 };
4021 
4022 /// Remove duplicated ConditionOp args.
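///
/// A rough sketch (hypothetical IR): if the same value is forwarded twice,
///
///  %0:2 = scf.while () : () -> (i32, i32) {
///    %condition = "test.condition"() : () -> i1
///    %v = "test.get_some_value"() : () -> i32
///    scf.condition(%condition) %v, %v : i32, i32
///  } do {
///  ^bb0(%arg0: i32, %arg1: i32):
///    "test.use"(%arg0, %arg1) : (i32, i32) -> ()
///    scf.yield
///  }
///
/// the duplicate is removed and all uses of %0#1 and %arg1 are redirected to
/// %0#0 and %arg0, respectively:
///
///  %0 = scf.while () : () -> (i32) {
///    %condition = "test.condition"() : () -> i1
///    %v = "test.get_some_value"() : () -> i32
///    scf.condition(%condition) %v : i32
///  } do {
///  ^bb0(%arg0: i32):
///    "test.use"(%arg0, %arg0) : (i32, i32) -> ()
///    scf.yield
///  }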
4023 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4024   using OpRewritePattern::OpRewritePattern;
4025 
4026   LogicalResult matchAndRewrite(WhileOp op,
4027                                 PatternRewriter &rewriter) const override {
4028     ConditionOp condOp = op.getConditionOp();
4029     ValueRange condOpArgs = condOp.getArgs();
4030 
4031     llvm::SmallPtrSet<Value, 8> argsSet;
4032     for (Value arg : condOpArgs)
4033       argsSet.insert(arg);
4034 
4035     if (argsSet.size() == condOpArgs.size())
4036       return rewriter.notifyMatchFailure(op, "No results to remove");
4037 
4038     llvm::SmallDenseMap<Value, unsigned> argsMap;
4039     SmallVector<Value> newArgs;
4040     argsMap.reserve(condOpArgs.size());
4041     newArgs.reserve(condOpArgs.size());
4042     for (Value arg : condOpArgs) {
4043       if (!argsMap.count(arg)) {
4044         auto pos = static_cast<unsigned>(argsMap.size());
4045         argsMap.insert({arg, pos});
4046         newArgs.emplace_back(arg);
4047       }
4048     }
4049 
4050     ValueRange argsRange(newArgs);
4051 
4052     Location loc = op.getLoc();
4053     auto newWhileOp = rewriter.create<scf::WhileOp>(
4054         loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4055         /*afterBody*/ nullptr);
4056     Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4057     Block &newAfterBlock = *newWhileOp.getAfterBody();
4058 
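    // For each original condition arg, record the deduplicated 'after' block
    // argument and op result it maps to, so uses of the old values can be
    // redirected below.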
4059     SmallVector<Value> afterArgsMapping;
4060     SmallVector<Value> resultsMapping;
4061     for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4062       auto it = argsMap.find(arg);
4063       assert(it != argsMap.end());
4064       auto pos = it->second;
4065       afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4066       resultsMapping.emplace_back(newWhileOp->getResult(pos));
4067     }
4068 
4069     OpBuilder::InsertionGuard g(rewriter);
4070     rewriter.setInsertionPoint(condOp);
4071     rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4072                                              argsRange);
4073 
4074     Block &beforeBlock = *op.getBeforeBody();
4075     Block &afterBlock = *op.getAfterBody();
4076 
4077     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4078                          newBeforeBlock.getArguments());
4079     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4080     rewriter.replaceOp(op, resultsMapping);
4081     return success();
4082   }
4083 };
4084 
4085 /// If both ranges contain the same values, return the mapping of indices from
4086 /// args2 to args1. Otherwise return std::nullopt.
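/// For example, for args1 = [%a, %b] and args2 = [%b, %a] the result is [1, 0].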
4087 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4088                                                            ValueRange args2) {
4089   if (args1.size() != args2.size())
4090     return std::nullopt;
4091 
4092   SmallVector<unsigned> ret(args1.size());
4093   for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4094     auto it = llvm::find(args2, arg1);
4095     if (it == args2.end())
4096       return std::nullopt;
4097 
4098     ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4099   }
4100 
4101   return ret;
4102 }
4103 
4104 static bool hasDuplicates(ValueRange args) {
4105   llvm::SmallDenseSet<Value> set;
4106   for (Value arg : args) {
4107     if (!set.insert(arg).second)
4108       return true;
4109   }
4110   return false;
4111 }
4112 
4113 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4114 /// `scf.condition` args into the same order as the block args. Update the
4115 /// `after` block args and op result values accordingly.
4116 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
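///
/// A rough sketch (hypothetical IR): if the 'before' block forwards its
/// arguments out of order,
///
///  ^bb0(%arg0: i32, %arg1: i64):
///    %cond = "test.condition"() : () -> i1
///    scf.condition(%cond) %arg1, %arg0 : i64, i32
///
/// the terminator is rewritten to `scf.condition(%cond) %arg0, %arg1 : i32, i64`
/// and the 'after' block arguments and the op results are permuted to match.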
4117 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4118   using OpRewritePattern::OpRewritePattern;
4119 
4120   LogicalResult matchAndRewrite(WhileOp loop,
4121                                 PatternRewriter &rewriter) const override {
4122     auto oldBefore = loop.getBeforeBody();
4123     ConditionOp oldTerm = loop.getConditionOp();
4124     ValueRange beforeArgs = oldBefore->getArguments();
4125     ValueRange termArgs = oldTerm.getArgs();
4126     if (beforeArgs == termArgs)
4127       return failure();
4128 
4129     if (hasDuplicates(termArgs))
4130       return failure();
4131 
4132     auto mapping = getArgsMapping(beforeArgs, termArgs);
4133     if (!mapping)
4134       return failure();
4135 
4136     {
4137       OpBuilder::InsertionGuard g(rewriter);
4138       rewriter.setInsertionPoint(oldTerm);
4139       rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4140                                                beforeArgs);
4141     }
4142 
4143     auto oldAfter = loop.getAfterBody();
4144 
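    // The new loop forwards the 'before' block arguments in their natural
    // order, so old result i becomes new result (*mapping)[i].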
4145     SmallVector<Type> newResultTypes(beforeArgs.size());
4146     for (auto &&[i, j] : llvm::enumerate(*mapping))
4147       newResultTypes[j] = loop.getResult(i).getType();
4148 
4149     auto newLoop = rewriter.create<WhileOp>(
4150         loop.getLoc(), newResultTypes, loop.getInits(),
4151         /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4152     auto newBefore = newLoop.getBeforeBody();
4153     auto newAfter = newLoop.getAfterBody();
4154 
4155     SmallVector<Value> newResults(beforeArgs.size());
4156     SmallVector<Value> newAfterArgs(beforeArgs.size());
4157     for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4158       newResults[i] = newLoop.getResult(j);
4159       newAfterArgs[i] = newAfter->getArgument(j);
4160     }
4161 
4162     rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4163                                newBefore->getArguments());
4164     rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4165                                newAfterArgs);
4166 
4167     rewriter.replaceOp(loop, newResults);
4168     return success();
4169   }
4170 };
4171 } // namespace
4172 
4173 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4174                                           MLIRContext *context) {
4175   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4176               RemoveLoopInvariantValueYielded, WhileConditionTruth,
4177               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4178               WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4179 }
4180 
4181 //===----------------------------------------------------------------------===//
4182 // IndexSwitchOp
4183 //===----------------------------------------------------------------------===//
4184 
4185 /// Parse the case regions and values.
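/// The expected textual form is a sequence of clauses such as (illustrative):
///
///   case 2 {
///     ...
///   }
///   case 5 {
///     ...
///   }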
4186 static ParseResult
4187 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4188                  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4189   SmallVector<int64_t> caseValues;
4190   while (succeeded(p.parseOptionalKeyword("case"))) {
4191     int64_t value;
4192     Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4193     if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4194       return failure();
4195     caseValues.push_back(value);
4196   }
4197   cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4198   return success();
4199 }
4200 
4201 /// Print the case regions and values.
4202 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4203                              DenseI64ArrayAttr cases, RegionRange caseRegions) {
4204   for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4205     p.printNewline();
4206     p << "case " << value << ' ';
4207     p.printRegion(*region, /*printEntryBlockArgs=*/false);
4208   }
4209 }
4210 
4211 LogicalResult scf::IndexSwitchOp::verify() {
4212   if (getCases().size() != getCaseRegions().size()) {
4213     return emitOpError("has ")
4214            << getCaseRegions().size() << " case regions but "
4215            << getCases().size() << " case values";
4216   }
4217 
4218   DenseSet<int64_t> valueSet;
4219   for (int64_t value : getCases())
4220     if (!valueSet.insert(value).second)
4221       return emitOpError("has duplicate case value: ") << value;
4222   auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4223     auto yield = dyn_cast<YieldOp>(region.front().back());
4224     if (!yield)
4225       return emitOpError("expected region to end with scf.yield, but got ")
4226              << region.front().back().getName();
4227 
4228     if (yield.getNumOperands() != getNumResults()) {
4229       return (emitOpError("expected each region to return ")
4230               << getNumResults() << " values, but " << name << " returns "
4231               << yield.getNumOperands())
4232                  .attachNote(yield.getLoc())
4233              << "see yield operation here";
4234     }
4235     for (auto [idx, result, operand] :
4236          llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4237                    yield.getOperandTypes())) {
4238       if (result == operand)
4239         continue;
4240       return (emitOpError("expected result #")
4241               << idx << " of each region to be " << result)
4242                  .attachNote(yield.getLoc())
4243              << name << " returns " << operand << " here";
4244     }
4245     return success();
4246   };
4247 
4248   if (failed(verifyRegion(getDefaultRegion(), "default region")))
4249     return failure();
4250   for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4251     if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4252       return failure();
4253 
4254   return success();
4255 }
4256 
4257 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4258 
4259 Block &scf::IndexSwitchOp::getDefaultBlock() {
4260   return getDefaultRegion().front();
4261 }
4262 
4263 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4264   assert(idx < getNumCases() && "case index out-of-bounds");
4265   return getCaseRegions()[idx].front();
4266 }
4267 
4268 void IndexSwitchOp::getSuccessorRegions(
4269     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4270   // All regions branch back to the parent op.
4271   if (!point.isParent()) {
4272     successors.emplace_back(getResults());
4273     return;
4274   }
4275 
4276   llvm::copy(getRegions(), std::back_inserter(successors));
4277 }
4278 
4279 void IndexSwitchOp::getEntrySuccessorRegions(
4280     ArrayRef<Attribute> operands,
4281     SmallVectorImpl<RegionSuccessor> &successors) {
4282   FoldAdaptor adaptor(operands, *this);
4283 
4284   // If a constant was not provided, all regions are possible successors.
4285   auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4286   if (!arg) {
4287     llvm::copy(getRegions(), std::back_inserter(successors));
4288     return;
4289   }
4290 
4291   // Otherwise, try to find a case with a matching value. If none is found,
4292   // the default region is the only successor.
4293   for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4294     if (caseValue == arg.getInt()) {
4295       successors.emplace_back(&caseRegion);
4296       return;
4297     }
4298   }
4299   successors.emplace_back(&getDefaultRegion());
4300 }
4301 
4302 void IndexSwitchOp::getRegionInvocationBounds(
4303     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4304   auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4305   if (!operandValue) {
4306     // All regions are invoked at most once.
4307     bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4308     return;
4309   }
4310 
4311   unsigned liveIndex = getNumRegions() - 1;
4312   const auto *it = llvm::find(getCases(), operandValue.getInt());
4313   if (it != getCases().end())
4314     liveIndex = std::distance(getCases().begin(), it);
4315   for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4316     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4317 }
4318 
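/// Fold an `scf.index_switch` whose argument is a known constant by inlining
/// the region selected by that constant (or the default region if no case
/// matches). A rough sketch (hypothetical IR):
///
///   %0 = scf.index_switch %c2 -> i32
///   case 2 {
///     %v = "test.foo"() : () -> i32
///     scf.yield %v : i32
///   }
///   default {
///     %d = "test.bar"() : () -> i32
///     scf.yield %d : i32
///   }
///
/// is replaced by the inlined body of `case 2`, with %0 replaced by %v.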
4319 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4320   using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4321 
4322   LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4323                                 PatternRewriter &rewriter) const override {
4324     // If `op.getArg()` is a constant, select the region that matches the
4325     // constant value. Use the default region if no match is found.
4326     std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4327     if (!maybeCst.has_value())
4328       return failure();
4329     int64_t cst = *maybeCst;
4330     int64_t caseIdx, e = op.getNumCases();
4331     for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4332       if (cst == op.getCases()[caseIdx])
4333         break;
4334     }
4335 
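    // If no case value matched, caseIdx equals the number of cases and the
    // default region is selected.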
4336     Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4337                                              : op.getDefaultRegion();
4338     Block &source = r.front();
4339     Operation *terminator = source.getTerminator();
4340     SmallVector<Value> results = terminator->getOperands();
4341 
4342     rewriter.inlineBlockBefore(&source, op);
4343     rewriter.eraseOp(terminator);
4344     // Replace the operation with a potentially empty list of results. This is
4345     // a pattern because the fold mechanism cannot handle an empty result list.
4346     rewriter.replaceOp(op, results);
4347 
4348     return success();
4349   }
4350 };
4351 
4352 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4353                                                 MLIRContext *context) {
4354   results.add<FoldConstantCase>(context);
4355 }
4356 
4357 //===----------------------------------------------------------------------===//
4358 // TableGen'd op method definitions
4359 //===----------------------------------------------------------------------===//
4360 
4361 #define GET_OP_CLASSES
4362 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
4363