1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/IR/SCF.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/Arith/Utils/Utils.h" 12 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/IRMapping.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Interfaces/FunctionInterfaces.h" 23 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/MapVector.h" 26 #include "llvm/ADT/SmallPtrSet.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 29 using namespace mlir; 30 using namespace mlir::scf; 31 32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 // SCFDialect Dialect Interfaces 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 struct SCFInlinerInterface : public DialectInlinerInterface { 40 using DialectInlinerInterface::DialectInlinerInterface; 41 // We don't have any special restrictions on what can be inlined into 42 // destination regions (e.g. while/conditional bodies). Always allow it. 
43 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 44 IRMapping &valueMapping) const final { 45 return true; 46 } 47 // Operations in scf dialect are always legal to inline since they are 48 // pure. 49 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 50 return true; 51 } 52 // Handle the given inlined terminator by replacing it with a new operation 53 // as necessary. Required when the region has only one block. 54 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 55 auto retValOp = dyn_cast<scf::YieldOp>(op); 56 if (!retValOp) 57 return; 58 59 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) { 60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue)); 61 } 62 } 63 }; 64 } // namespace 65 66 //===----------------------------------------------------------------------===// 67 // SCFDialect 68 //===----------------------------------------------------------------------===// 69 70 void SCFDialect::initialize() { 71 addOperations< 72 #define GET_OP_LIST 73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" 74 >(); 75 addInterfaces<SCFInlinerInterface>(); 76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface, 77 InParallelOp, ReduceReturnOp>(); 78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp, 79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp, 80 ForallOp, InParallelOp, WhileOp, YieldOp>(); 81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>(); 82 } 83 84 /// Default callback for IfOp builders. Inserts a yield without arguments. 85 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { 86 builder.create<scf::YieldOp>(loc); 87 } 88 89 /// Verifies that the first block of the given `region` is terminated by a 90 /// TerminatorTy. Reports errors on the given operation if it is not the case. 
91 template <typename TerminatorTy> 92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, 93 StringRef errorMessage) { 94 Operation *terminatorOperation = nullptr; 95 if (!region.empty() && !region.front().empty()) { 96 terminatorOperation = ®ion.front().back(); 97 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation)) 98 return yield; 99 } 100 auto diag = op->emitOpError(errorMessage); 101 if (terminatorOperation) 102 diag.attachNote(terminatorOperation->getLoc()) << "terminator here"; 103 return nullptr; 104 } 105 106 //===----------------------------------------------------------------------===// 107 // ExecuteRegionOp 108 //===----------------------------------------------------------------------===// 109 110 /// Replaces the given op with the contents of the given single-block region, 111 /// using the operands of the block terminator to replace operation results. 112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, 113 Region ®ion, ValueRange blockArgs = {}) { 114 assert(llvm::hasSingleElement(region) && "expected single-region block"); 115 Block *block = ®ion.front(); 116 Operation *terminator = block->getTerminator(); 117 ValueRange results = terminator->getOperands(); 118 rewriter.inlineBlockBefore(block, op, blockArgs); 119 rewriter.replaceOp(op, results); 120 rewriter.eraseOp(terminator); 121 } 122 123 /// 124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` 125 /// block+ 126 /// `}` 127 /// 128 /// Example: 129 /// scf.execute_region -> i32 { 130 /// %idx = load %rI[%i] : memref<128xi32> 131 /// return %idx : i32 132 /// } 133 /// 134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, 135 OperationState &result) { 136 if (parser.parseOptionalArrowTypeList(result.types)) 137 return failure(); 138 139 // Introduce the body region and parse it. 
140 Region *body = result.addRegion(); 141 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || 142 parser.parseOptionalAttrDict(result.attributes)) 143 return failure(); 144 145 return success(); 146 } 147 148 void ExecuteRegionOp::print(OpAsmPrinter &p) { 149 p.printOptionalArrowTypeList(getResultTypes()); 150 151 p << ' '; 152 p.printRegion(getRegion(), 153 /*printEntryBlockArgs=*/false, 154 /*printBlockTerminators=*/true); 155 156 p.printOptionalAttrDict((*this)->getAttrs()); 157 } 158 159 LogicalResult ExecuteRegionOp::verify() { 160 if (getRegion().empty()) 161 return emitOpError("region needs to have at least one block"); 162 if (getRegion().front().getNumArguments() > 0) 163 return emitOpError("region cannot have any arguments"); 164 return success(); 165 } 166 167 // Inline an ExecuteRegionOp if it only contains one block. 168 // "test.foo"() : () -> () 169 // %v = scf.execute_region -> i64 { 170 // %x = "test.val"() : () -> i64 171 // scf.yield %x : i64 172 // } 173 // "test.bar"(%v) : (i64) -> () 174 // 175 // becomes 176 // 177 // "test.foo"() : () -> () 178 // %x = "test.val"() : () -> i64 179 // "test.bar"(%x) : (i64) -> () 180 // 181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { 182 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; 183 184 LogicalResult matchAndRewrite(ExecuteRegionOp op, 185 PatternRewriter &rewriter) const override { 186 if (!llvm::hasSingleElement(op.getRegion())) 187 return failure(); 188 replaceOpWithRegion(rewriter, op, op.getRegion()); 189 return success(); 190 } 191 }; 192 193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks. 194 // TODO generalize the conditions for operations which can be inlined into. 
195 // func @func_execute_region_elim() { 196 // "test.foo"() : () -> () 197 // %v = scf.execute_region -> i64 { 198 // %c = "test.cmp"() : () -> i1 199 // cf.cond_br %c, ^bb2, ^bb3 200 // ^bb2: 201 // %x = "test.val1"() : () -> i64 202 // cf.br ^bb4(%x : i64) 203 // ^bb3: 204 // %y = "test.val2"() : () -> i64 205 // cf.br ^bb4(%y : i64) 206 // ^bb4(%z : i64): 207 // scf.yield %z : i64 208 // } 209 // "test.bar"(%v) : (i64) -> () 210 // return 211 // } 212 // 213 // becomes 214 // 215 // func @func_execute_region_elim() { 216 // "test.foo"() : () -> () 217 // %c = "test.cmp"() : () -> i1 218 // cf.cond_br %c, ^bb1, ^bb2 219 // ^bb1: // pred: ^bb0 220 // %x = "test.val1"() : () -> i64 221 // cf.br ^bb3(%x : i64) 222 // ^bb2: // pred: ^bb0 223 // %y = "test.val2"() : () -> i64 224 // cf.br ^bb3(%y : i64) 225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 226 // "test.bar"(%z) : (i64) -> () 227 // return 228 // } 229 // 230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { 231 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; 232 233 LogicalResult matchAndRewrite(ExecuteRegionOp op, 234 PatternRewriter &rewriter) const override { 235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp())) 236 return failure(); 237 238 Block *prevBlock = op->getBlock(); 239 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator()); 240 rewriter.setInsertionPointToEnd(prevBlock); 241 242 rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front()); 243 244 for (Block &blk : op.getRegion()) { 245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) { 246 rewriter.setInsertionPoint(yieldOp); 247 rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock, 248 yieldOp.getResults()); 249 rewriter.eraseOp(yieldOp); 250 } 251 } 252 253 rewriter.inlineRegionBefore(op.getRegion(), postBlock); 254 SmallVector<Value> blockArgs; 255 256 for (auto res : op.getResults()) 257 blockArgs.push_back(postBlock->addArgument(res.getType(), 
res.getLoc())); 258 259 rewriter.replaceOp(op, blockArgs); 260 return success(); 261 } 262 }; 263 264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, 265 MLIRContext *context) { 266 results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); 267 } 268 269 void ExecuteRegionOp::getSuccessorRegions( 270 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 271 // If the predecessor is the ExecuteRegionOp, branch into the body. 272 if (point.isParent()) { 273 regions.push_back(RegionSuccessor(&getRegion())); 274 return; 275 } 276 277 // Otherwise, the region branches back to the parent operation. 278 regions.push_back(RegionSuccessor(getResults())); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // ConditionOp 283 //===----------------------------------------------------------------------===// 284 285 MutableOperandRange 286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { 287 assert((point.isParent() || point == getParentOp().getAfter()) && 288 "condition op can only exit the loop or branch to the after" 289 "region"); 290 // Pass all operands except the condition to the successor region. 291 return getArgsMutable(); 292 } 293 294 void ConditionOp::getSuccessorRegions( 295 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { 296 FoldAdaptor adaptor(operands, *this); 297 298 WhileOp whileOp = getParentOp(); 299 300 // Condition can either lead to the after region or back to the parent op 301 // depending on whether the condition is true or not. 
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); 303 if (!boolAttr || boolAttr.getValue()) 304 regions.emplace_back(&whileOp.getAfter(), 305 whileOp.getAfter().getArguments()); 306 if (!boolAttr || !boolAttr.getValue()) 307 regions.emplace_back(whileOp.getResults()); 308 } 309 310 //===----------------------------------------------------------------------===// 311 // ForOp 312 //===----------------------------------------------------------------------===// 313 314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, 315 Value ub, Value step, ValueRange initArgs, 316 BodyBuilderFn bodyBuilder) { 317 OpBuilder::InsertionGuard guard(builder); 318 319 result.addOperands({lb, ub, step}); 320 result.addOperands(initArgs); 321 for (Value v : initArgs) 322 result.addTypes(v.getType()); 323 Type t = lb.getType(); 324 Region *bodyRegion = result.addRegion(); 325 Block *bodyBlock = builder.createBlock(bodyRegion); 326 bodyBlock->addArgument(t, result.location); 327 for (Value v : initArgs) 328 bodyBlock->addArgument(v.getType(), v.getLoc()); 329 330 // Create the default terminator if the builder is not provided and if the 331 // iteration arguments are not provided. Otherwise, leave this to the caller 332 // because we don't know which values to return from the loop. 333 if (initArgs.empty() && !bodyBuilder) { 334 ForOp::ensureTerminator(*bodyRegion, builder, result.location); 335 } else if (bodyBuilder) { 336 OpBuilder::InsertionGuard guard(builder); 337 builder.setInsertionPointToStart(bodyBlock); 338 bodyBuilder(builder, result.location, bodyBlock->getArgument(0), 339 bodyBlock->getArguments().drop_front()); 340 } 341 } 342 343 LogicalResult ForOp::verify() { 344 // Check that the number of init args and op results is the same. 
345 if (getInitArgs().size() != getNumResults()) 346 return emitOpError( 347 "mismatch in number of loop-carried values and defined values"); 348 349 return success(); 350 } 351 352 LogicalResult ForOp::verifyRegions() { 353 // Check that the body defines as single block argument for the induction 354 // variable. 355 if (getInductionVar().getType() != getLowerBound().getType()) 356 return emitOpError( 357 "expected induction variable to be same type as bounds and step"); 358 359 if (getNumRegionIterArgs() != getNumResults()) 360 return emitOpError( 361 "mismatch in number of basic block args and defined values"); 362 363 auto initArgs = getInitArgs(); 364 auto iterArgs = getRegionIterArgs(); 365 auto opResults = getResults(); 366 unsigned i = 0; 367 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) { 368 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 369 return emitOpError() << "types mismatch between " << i 370 << "th iter operand and defined value"; 371 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 372 return emitOpError() << "types mismatch between " << i 373 << "th iter region arg and defined value"; 374 375 ++i; 376 } 377 return success(); 378 } 379 380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() { 381 return SmallVector<Value>{getInductionVar()}; 382 } 383 384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() { 385 return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())}; 386 } 387 388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() { 389 return SmallVector<OpFoldResult>{OpFoldResult(getStep())}; 390 } 391 392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() { 393 return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())}; 394 } 395 396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); } 397 398 /// Promotes the loop body of a forOp to its containing block if the forOp 399 /// it can be determined that the loop has a 
single iteration. 400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { 401 std::optional<int64_t> tripCount = 402 constantTripCount(getLowerBound(), getUpperBound(), getStep()); 403 if (!tripCount.has_value() || tripCount != 1) 404 return failure(); 405 406 // Replace all results with the yielded values. 407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); 408 rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); 409 410 // Replace block arguments with lower bound (replacement for IV) and 411 // iter_args. 412 SmallVector<Value> bbArgReplacements; 413 bbArgReplacements.push_back(getLowerBound()); 414 llvm::append_range(bbArgReplacements, getInitArgs()); 415 416 // Move the loop body operations to the loop's containing block. 417 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(), 418 getOperation()->getIterator(), bbArgReplacements); 419 420 // Erase the old terminator and the loop. 421 rewriter.eraseOp(yieldOp); 422 rewriter.eraseOp(*this); 423 424 return success(); 425 } 426 427 /// Prints the initialization list in the form of 428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 429 /// where 'inner' values are assumed to be region arguments and 'outer' values 430 /// are regular SSA values. 
431 static void printInitializationList(OpAsmPrinter &p, 432 Block::BlockArgListType blocksArgs, 433 ValueRange initializers, 434 StringRef prefix = "") { 435 assert(blocksArgs.size() == initializers.size() && 436 "expected same length of arguments and initializers"); 437 if (initializers.empty()) 438 return; 439 440 p << prefix << '('; 441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 442 p << std::get<0>(it) << " = " << std::get<1>(it); 443 }); 444 p << ")"; 445 } 446 447 void ForOp::print(OpAsmPrinter &p) { 448 p << " " << getInductionVar() << " = " << getLowerBound() << " to " 449 << getUpperBound() << " step " << getStep(); 450 451 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 452 if (!getInitArgs().empty()) 453 p << " -> (" << getInitArgs().getTypes() << ')'; 454 p << ' '; 455 if (Type t = getInductionVar().getType(); !t.isIndex()) 456 p << " : " << t << ' '; 457 p.printRegion(getRegion(), 458 /*printEntryBlockArgs=*/false, 459 /*printBlockTerminators=*/!getInitArgs().empty()); 460 p.printOptionalAttrDict((*this)->getAttrs()); 461 } 462 463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { 464 auto &builder = parser.getBuilder(); 465 Type type; 466 467 OpAsmParser::Argument inductionVariable; 468 OpAsmParser::UnresolvedOperand lb, ub, step; 469 470 // Parse the induction variable followed by '='. 471 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || 472 // Parse loop bounds. 473 parser.parseOperand(lb) || parser.parseKeyword("to") || 474 parser.parseOperand(ub) || parser.parseKeyword("step") || 475 parser.parseOperand(step)) 476 return failure(); 477 478 // Parse the optional initial iteration arguments. 
479 SmallVector<OpAsmParser::Argument, 4> regionArgs; 480 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; 481 regionArgs.push_back(inductionVariable); 482 483 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 484 if (hasIterArgs) { 485 // Parse assignment list and results type list. 486 if (parser.parseAssignmentList(regionArgs, operands) || 487 parser.parseArrowTypeList(result.types)) 488 return failure(); 489 } 490 491 if (regionArgs.size() != result.types.size() + 1) 492 return parser.emitError( 493 parser.getNameLoc(), 494 "mismatch in number of loop-carried values and defined values"); 495 496 // Parse optional type, else assume Index. 497 if (parser.parseOptionalColon()) 498 type = builder.getIndexType(); 499 else if (parser.parseType(type)) 500 return failure(); 501 502 // Resolve input operands. 503 regionArgs.front().type = type; 504 if (parser.resolveOperand(lb, type, result.operands) || 505 parser.resolveOperand(ub, type, result.operands) || 506 parser.resolveOperand(step, type, result.operands)) 507 return failure(); 508 if (hasIterArgs) { 509 for (auto argOperandType : 510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) { 511 Type type = std::get<2>(argOperandType); 512 std::get<0>(argOperandType).type = type; 513 if (parser.resolveOperand(std::get<1>(argOperandType), type, 514 result.operands)) 515 return failure(); 516 } 517 } 518 519 // Parse the body region. 520 Region *body = result.addRegion(); 521 if (parser.parseRegion(*body, regionArgs)) 522 return failure(); 523 524 ForOp::ensureTerminator(*body, builder, result.location); 525 526 // Parse the optional attribute list. 
527 if (parser.parseOptionalAttrDict(result.attributes)) 528 return failure(); 529 530 return success(); 531 } 532 533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; } 534 535 Block::BlockArgListType ForOp::getRegionIterArgs() { 536 return getBody()->getArguments().drop_front(getNumInductionVars()); 537 } 538 539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() { 540 return getInitArgsMutable(); 541 } 542 543 FailureOr<LoopLikeOpInterface> 544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, 545 ValueRange newInitOperands, 546 bool replaceInitOperandUsesInLoop, 547 const NewYieldValuesFn &newYieldValuesFn) { 548 // Create a new loop before the existing one, with the extra operands. 549 OpBuilder::InsertionGuard g(rewriter); 550 rewriter.setInsertionPoint(getOperation()); 551 auto inits = llvm::to_vector(getInitArgs()); 552 inits.append(newInitOperands.begin(), newInitOperands.end()); 553 scf::ForOp newLoop = rewriter.create<scf::ForOp>( 554 getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, 555 [](OpBuilder &, Location, Value, ValueRange) {}); 556 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); 557 558 // Generate the new yield values and append them to the scf.yield operation. 559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); 560 ArrayRef<BlockArgument> newIterArgs = 561 newLoop.getBody()->getArguments().take_back(newInitOperands.size()); 562 { 563 OpBuilder::InsertionGuard g(rewriter); 564 rewriter.setInsertionPoint(yieldOp); 565 SmallVector<Value> newYieldedValues = 566 newYieldValuesFn(rewriter, getLoc(), newIterArgs); 567 assert(newInitOperands.size() == newYieldedValues.size() && 568 "expected as many new yield values as new iter operands"); 569 rewriter.modifyOpInPlace(yieldOp, [&]() { 570 yieldOp.getResultsMutable().append(newYieldedValues); 571 }); 572 } 573 574 // Move the loop body to the new op. 
575 rewriter.mergeBlocks(getBody(), newLoop.getBody(), 576 newLoop.getBody()->getArguments().take_front( 577 getBody()->getNumArguments())); 578 579 if (replaceInitOperandUsesInLoop) { 580 // Replace all uses of `newInitOperands` with the corresponding basic block 581 // arguments. 582 for (auto it : llvm::zip(newInitOperands, newIterArgs)) { 583 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), 584 [&](OpOperand &use) { 585 Operation *user = use.getOwner(); 586 return newLoop->isProperAncestor(user); 587 }); 588 } 589 } 590 591 // Replace the old loop. 592 rewriter.replaceOp(getOperation(), 593 newLoop->getResults().take_front(getNumResults())); 594 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 595 } 596 597 ForOp mlir::scf::getForInductionVarOwner(Value val) { 598 auto ivArg = llvm::dyn_cast<BlockArgument>(val); 599 if (!ivArg) 600 return ForOp(); 601 assert(ivArg.getOwner() && "unlinked block argument"); 602 auto *containingOp = ivArg.getOwner()->getParentOp(); 603 return dyn_cast_or_null<ForOp>(containingOp); 604 } 605 606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { 607 return getInitArgs(); 608 } 609 610 void ForOp::getSuccessorRegions(RegionBranchPoint point, 611 SmallVectorImpl<RegionSuccessor> ®ions) { 612 // Both the operation itself and the region may be branching into the body or 613 // back into the operation itself. It is possible for loop not to enter the 614 // body. 615 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 616 regions.push_back(RegionSuccessor(getResults())); 617 } 618 619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } 620 621 /// Promotes the loop body of a forallOp to its containing block if it can be 622 /// determined that the loop has a single iteration. 
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { 624 for (auto [lb, ub, step] : 625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { 626 auto tripCount = constantTripCount(lb, ub, step); 627 if (!tripCount.has_value() || *tripCount != 1) 628 return failure(); 629 } 630 631 promote(rewriter, *this); 632 return success(); 633 } 634 635 Block::BlockArgListType ForallOp::getRegionIterArgs() { 636 return getBody()->getArguments().drop_front(getRank()); 637 } 638 639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { 640 return getOutputsMutable(); 641 } 642 643 /// Promotes the loop body of a scf::ForallOp to its containing block. 644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { 645 OpBuilder::InsertionGuard g(rewriter); 646 scf::InParallelOp terminator = forallOp.getTerminator(); 647 648 // Replace block arguments with lower bounds (replacements for IVs) and 649 // outputs. 650 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter); 651 bbArgReplacements.append(forallOp.getOutputs().begin(), 652 forallOp.getOutputs().end()); 653 654 // Move the loop body operations to the loop's containing block. 655 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(), 656 forallOp->getIterator(), bbArgReplacements); 657 658 // Replace the terminator with tensor.insert_slice ops. 
659 rewriter.setInsertionPointAfter(forallOp); 660 SmallVector<Value> results; 661 results.reserve(forallOp.getResults().size()); 662 for (auto &yieldingOp : terminator.getYieldingOps()) { 663 auto parallelInsertSliceOp = 664 cast<tensor::ParallelInsertSliceOp>(yieldingOp); 665 666 Value dst = parallelInsertSliceOp.getDest(); 667 Value src = parallelInsertSliceOp.getSource(); 668 if (llvm::isa<TensorType>(src.getType())) { 669 results.push_back(rewriter.create<tensor::InsertSliceOp>( 670 forallOp.getLoc(), dst.getType(), src, dst, 671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), 672 parallelInsertSliceOp.getStrides(), 673 parallelInsertSliceOp.getStaticOffsets(), 674 parallelInsertSliceOp.getStaticSizes(), 675 parallelInsertSliceOp.getStaticStrides())); 676 } else { 677 llvm_unreachable("unsupported terminator"); 678 } 679 } 680 rewriter.replaceAllUsesWith(forallOp.getResults(), results); 681 682 // Erase the old terminator and the loop. 683 rewriter.eraseOp(terminator); 684 rewriter.eraseOp(forallOp); 685 } 686 687 LoopNest mlir::scf::buildLoopNest( 688 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, 689 ValueRange steps, ValueRange iterArgs, 690 function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> 691 bodyBuilder) { 692 assert(lbs.size() == ubs.size() && 693 "expected the same number of lower and upper bounds"); 694 assert(lbs.size() == steps.size() && 695 "expected the same number of lower bounds and steps"); 696 697 // If there are no bounds, call the body-building function and return early. 698 if (lbs.empty()) { 699 ValueVector results = 700 bodyBuilder ? 
bodyBuilder(builder, loc, ValueRange(), iterArgs) 701 : ValueVector(); 702 assert(results.size() == iterArgs.size() && 703 "loop nest body must return as many values as loop has iteration " 704 "arguments"); 705 return LoopNest{{}, std::move(results)}; 706 } 707 708 // First, create the loop structure iteratively using the body-builder 709 // callback of `ForOp::build`. Do not create `YieldOp`s yet. 710 OpBuilder::InsertionGuard guard(builder); 711 SmallVector<scf::ForOp, 4> loops; 712 SmallVector<Value, 4> ivs; 713 loops.reserve(lbs.size()); 714 ivs.reserve(lbs.size()); 715 ValueRange currentIterArgs = iterArgs; 716 Location currentLoc = loc; 717 for (unsigned i = 0, e = lbs.size(); i < e; ++i) { 718 auto loop = builder.create<scf::ForOp>( 719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, 720 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, 721 ValueRange args) { 722 ivs.push_back(iv); 723 // It is safe to store ValueRange args because it points to block 724 // arguments of a loop operation that we also own. 725 currentIterArgs = args; 726 currentLoc = nestedLoc; 727 }); 728 // Set the builder to point to the body of the newly created loop. We don't 729 // do this in the callback because the builder is reset when the callback 730 // returns. 731 builder.setInsertionPointToStart(loop.getBody()); 732 loops.push_back(loop); 733 } 734 735 // For all loops but the innermost, yield the results of the nested loop. 736 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) { 737 builder.setInsertionPointToEnd(loops[i].getBody()); 738 builder.create<scf::YieldOp>(loc, loops[i + 1].getResults()); 739 } 740 741 // In the body of the innermost loop, call the body building function if any 742 // and yield its results. 743 builder.setInsertionPointToStart(loops.back().getBody()); 744 ValueVector results = bodyBuilder 745 ? 
bodyBuilder(builder, currentLoc, ivs, 746 loops.back().getRegionIterArgs()) 747 : ValueVector(); 748 assert(results.size() == iterArgs.size() && 749 "loop nest body must return as many values as loop has iteration " 750 "arguments"); 751 builder.setInsertionPointToEnd(loops.back().getBody()); 752 builder.create<scf::YieldOp>(loc, results); 753 754 // Return the loops. 755 ValueVector nestResults; 756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults)); 757 return LoopNest{std::move(loops), std::move(nestResults)}; 758 } 759 760 LoopNest mlir::scf::buildLoopNest( 761 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, 762 ValueRange steps, 763 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { 764 // Delegate to the main function by wrapping the body builder. 765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt, 766 [&bodyBuilder](OpBuilder &nestedBuilder, 767 Location nestedLoc, ValueRange ivs, 768 ValueRange) -> ValueVector { 769 if (bodyBuilder) 770 bodyBuilder(nestedBuilder, nestedLoc, ivs); 771 return {}; 772 }); 773 } 774 775 SmallVector<Value> 776 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, 777 OpOperand &operand, Value replacement, 778 const ValueTypeCastFnTy &castFn) { 779 assert(operand.getOwner() == forOp); 780 Type oldType = operand.get().getType(), newType = replacement.getType(); 781 782 // 1. Create new iter operands, exactly 1 is replaced. 783 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && 784 "expected an iter OpOperand"); 785 assert(operand.get().getType() != replacement.getType() && 786 "Expected a different type"); 787 SmallVector<Value> newIterOperands; 788 for (OpOperand &opOperand : forOp.getInitArgsMutable()) { 789 if (opOperand.getOperandNumber() == operand.getOperandNumber()) { 790 newIterOperands.push_back(replacement); 791 continue; 792 } 793 newIterOperands.push_back(opOperand.get()); 794 } 795 796 // 2. 
Create the new forOp shell.
  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newIterOperands);
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&newBlock);
  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
      &newForOp->getOpOperand(operand.getOperandNumber()));
  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // The yield operand index is the bbArg index minus the induction variable(s)
  // that precede the iter args in the block argument list.
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
                         clonedYieldOp.getOperand(yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  // Create the replacement yield first, then erase the merged-in one.
  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp so external users keep
  // seeing the original (old) type.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] =
      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newResults;
}

namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// These arguments must be defined outside of
// the ForOp region and can just be forwarded after simplifying the op inits,
// yields and returns.
//
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    // Set to true as soon as at least one iter arg can be folded away.
    bool canonicalize = false;

    // An internal flat vector of block transfer
    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
    // transformed block argument mappings. This plays the role of a
    // IRMapping for the particular use case of calling into
    // `inlineBlockBefore`.
    int64_t numResults = forOp.getNumResults();
    // keepMask[i] is true iff the i-th iter arg survives into the new loop.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(numResults);
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + numResults);
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getInitArgs().size());
    newYieldValues.reserve(numResults);
    newResultValues.reserve(numResults);
    for (auto [init, arg, result, yielded] :
         llvm::zip(forOp.getInitArgs(),       // iter from outside
                   forOp.getRegionIterArgs(), // iter inside region
                   forOp.getResults(),        // op results
                   forOp.getYieldedValues()   // iter yield
                   )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded (loop-invariant pass-through).
      // 2) The corresponding input (init) is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      //    result has no use.
      bool forwarded = (arg == yielded) || (init == yielded) ||
                       (arg.use_empty() && result.use_empty());
      keepMask.push_back(!forwarded);
      canonicalize |= forwarded;
      if (forwarded) {
        // The bbArg and the result are both replaced by the init value.
        newBlockTransferArgs.push_back(init);
        newResultValues.push_back(init);
        continue;
      }
      newIterArgs.push_back(init);
      newYieldValues.push_back(yielded);
      newBlockTransferArgs.push_back(Value()); // placeholder with null value
      newResultValues.push_back(Value());      // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      // A placeholder in one vector implies a placeholder in the other.
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      // The old yield now sits right before the new (empty) terminator.
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(forOp, newResultValues);
      return success();
    }

    // Otherwise: merge the old block and rewrite the merged terminator so it
    // only yields the operands of the surviving iter args.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
                                    filteredOperands);
    };

    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
    return success();
  }
};

/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference between two AffineValueMap is
/// dynamic.
/// NOTE(review): getSExtValue() assumes the APInt fits in 64 bits — presumably
/// fine for index/iN loop bounds seen here; confirm for very wide integers.
static std::optional<int64_t> computeConstDiff(Value l, Value u) {
  IntegerAttr clb, cub;
  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
    llvm::APInt lbValue = clb.getValue();
    llvm::APInt ubValue = cub.getValue();
    return (ubValue - lbValue).getSExtValue();
  }

  // Else a simple pattern match for x + c or c + x.
  llvm::APInt diff;
  if (matchPattern(
          u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
      matchPattern(
          u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
    return diff.getSExtValue();
  return std::nullopt;
}

/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // If the upper bound is the same as the lower bound, the loop does not
    // iterate, just remove it.
    if (op.getLowerBound() == op.getUpperBound()) {
      rewriter.replaceOp(op, op.getInitArgs());
      return success();
    }

    std::optional<int64_t> diff =
        computeConstDiff(op.getLowerBound(), op.getUpperBound());
    if (!diff)
      return failure();

    // If the loop is known to have 0 iterations, remove it.
    if (*diff <= 0) {
      rewriter.replaceOp(op, op.getInitArgs());
      return success();
    }

    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
    if (!maybeStepValue)
      return failure();

    // If the loop is known to have 1 iteration, inline its body and remove the
    // loop. (step >= ub - lb implies exactly one iteration since diff > 0.)
    llvm::APInt stepValue = *maybeStepValue;
    if (stepValue.sge(*diff)) {
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(op.getInitArgs().size() + 1);
      // The induction variable takes the lower-bound value on that single
      // iteration.
      blockArgs.push_back(op.getLowerBound());
      llvm::append_range(blockArgs, op.getInitArgs());
      replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
      return success();
    }

    // Now we are left with loops that have more than 1 iterations.
    Block &block = op.getRegion().front();
    // "Empty" loop: the body contains only the terminator.
    if (!llvm::hasSingleElement(block))
      return failure();
    // If the loop is empty, iterates at least once, and only returns values
    // defined outside of the loop, remove it and replace it with yield values.
    if (llvm::any_of(op.getYieldedValues(),
                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
      return failure();
    rewriter.replaceOp(op, op.getYieldedValues());
    return success();
  }
};

/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
/// ```
///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
///      -> (tensor<?x?xf32>) {
///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
///     scf.yield %2 : tensor<?x?xf32>
///   }
///   use_of(%1)
/// ```
///
/// folds into:
///
/// ```
///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
///       -> (tensor<32x1024xf32>) {
///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
///     scf.yield %4 : tensor<32x1024xf32>
///   }
///   %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
///   use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
      OpOperand &iterOpOperand = std::get<0>(it);
      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
      // Skip no-op casts (identical source/dest types).
      if (!incomingCast ||
          incomingCast.getSource().getType() == incomingCast.getType())
        continue;
      // Skip if the dest type of the cast does not preserve static information
      // present in the source type.
      if (!tensor::preservesStaticInformation(
              incomingCast.getDest().getType(),
              incomingCast.getSource().getType()))
        continue;
      // Only fold when the loop result feeding the outgoing cast is the sole
      // user.
      if (!std::get<1>(it).hasOneUse())
        continue;

      // Create a new ForOp with that iter operand replaced.
      rewriter.replaceOp(
          op, replaceAndCastForOpIterArg(
                  rewriter, op, iterOpOperand, incomingCast.getSource(),
                  [](OpBuilder &b, Location loc, Type type, Value source) {
                    return b.create<tensor::CastOp>(loc, type, source);
                  }));
      return success();
    }
    return failure();
  }
};

} // namespace

void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
      context);
}

/// Returns the step as an APInt if it is a constant, std::nullopt otherwise.
std::optional<APInt> ForOp::getConstantStep() {
  IntegerAttr step;
  if (matchPattern(getStep(), m_Constant(&step)))
    return step.getValue();
  return {};
}

/// Returns the mutable operand range of the terminating scf.yield.
std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
}

Speculation::Speculatability ForOp::getSpeculatability() {
  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
  // and End.
  if (auto constantStep = getConstantStep())
    if (*constantStep == 1)
      return Speculation::RecursivelySpeculatable;

  // For Step != 1, the loop may not terminate. We can add more smarts here if
  // needed.
  return Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//

LogicalResult ForallOp::verify() {
  unsigned numLoops = getRank();
  // Check number of outputs.
  if (getNumResults() != getOutputs().size())
    return emitOpError("produces ")
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs";

  // Check that the body defines block arguments for thread indices and outputs.
  auto *body = getBody();
  if (body->getNumArguments() != numLoops + getOutputs().size())
    return emitOpError("region expects ") << numLoops << " arguments";
  for (int64_t i = 0; i < numLoops; ++i)
    if (!body->getArgument(i).getType().isIndex())
      return emitOpError("expects ")
             << i << "-th block argument to be an index";
  for (unsigned i = 0; i < getOutputs().size(); ++i)
    if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
      return emitOpError("type mismatch between ")
             << i << "-th output and corresponding block argument";
  if (getMapping().has_value() && !getMapping()->empty()) {
    if (static_cast<int64_t>(getMapping()->size()) != numLoops)
      return emitOpError() << "mapping attribute size must match op rank";
    for (auto map : getMapping()->getValue()) {
      if (!isa<DeviceMappingAttrInterface>(map))
        return emitOpError()
               << getMappingAttrName() << " is not device mapping attribute";
    }
  }

  // Verify mixed static/dynamic control variables.
  Operation *op = getOperation();
  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
                                            getStaticLowerBound(),
                                            getDynamicLowerBound())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
                                            getStaticUpperBound(),
                                            getDynamicUpperBound())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
                                            getStaticStep(), getDynamicStep())))
    return failure();

  return success();
}

void ForallOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  p << " (" << getInductionVars();
  if (isNormalized()) {
    // Normalized form elides the `= (lbs) to` / `step (steps)` parts.
    p << ") in ";
    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
                          OpAsmParser::Delimiter::Paren);
  } else {
    p << ") = ";
    printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
                          OpAsmParser::Delimiter::Paren);
    p << " to ";
    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
                          OpAsmParser::Delimiter::Paren);
    p << " step ";
    printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
                          /*valueTypes=*/{}, /*scalables=*/{},
                          OpAsmParser::Delimiter::Paren);
  }
  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
  p << " ";
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults() > 0);
  // Elide attributes that are reconstructed by the parser.
  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
                                           getStaticLowerBoundAttrName(),
                                           getStaticUpperBoundAttrName(),
                                           getStaticStepAttrName()});
}

ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
  OpBuilder b(parser.getContext());
  auto indexType = b.getIndexType();

  // Parse an opening `(` followed by thread index variables followed by `)`
  // TODO: when we can refer to such "induction variable"-like handles from the
  // declarative assembly format, we can implement the parser as a custom hook.
  SmallVector<OpAsmParser::Argument, 4> ivs;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
    return failure();

  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
      dynamicSteps;
  if (succeeded(parser.parseOptionalKeyword("in"))) {
    // Normalized form: only upper bounds are spelled out.
    if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                              /*valueTypes=*/nullptr,
                              OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(dynamicUbs, indexType, result.operands))
      return failure();

    // Synthesize zero lower bounds and unit steps.
    unsigned numLoops = ivs.size();
    staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
    staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
  } else {
    // Parse lower bounds.
    if (parser.parseEqual() ||
        parseDynamicIndexList(parser, dynamicLbs, staticLbs,
                              /*valueTypes=*/nullptr,
                              OpAsmParser::Delimiter::Paren) ||

        parser.resolveOperands(dynamicLbs, indexType, result.operands))
      return failure();

    // Parse upper bounds.
    if (parser.parseKeyword("to") ||
        parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                              /*valueTypes=*/nullptr,
                              OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(dynamicUbs, indexType, result.operands))
      return failure();

    // Parse step values.
    if (parser.parseKeyword("step") ||
        parseDynamicIndexList(parser, dynamicSteps, staticSteps,
                              /*valueTypes=*/nullptr,
                              OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(dynamicSteps, indexType, result.operands))
      return failure();
  }

  // Parse out operands and results.
  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
  SMLoc outOperandsLoc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
    // NOTE(review): `outOperands` and `result.types` are both still empty at
    // this point, so this check can never fire; the real count mismatch is
    // diagnosed by resolveOperands below. Consider moving or removing it.
    if (outOperands.size() != result.types.size())
      return parser.emitError(outOperandsLoc,
                              "mismatch between out operands and types");
    if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
        parser.parseOptionalArrowTypeList(result.types) ||
        parser.resolveOperands(outOperands, result.types, outOperandsLoc,
                               result.operands))
      return failure();
  }

  // Parse region.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  for (auto &iv : ivs) {
    iv.type = b.getIndexType();
    regionArgs.push_back(iv);
  }
  for (const auto &it : llvm::enumerate(regionOutArgs)) {
    auto &out = it.value();
    out.type = result.types[it.index()];
    regionArgs.push_back(out);
  }
  if (parser.parseRegion(*region, regionArgs))
    return failure();

  // Ensure terminator and move region.
  ForallOp::ensureTerminator(*region, b, result.location);
  result.addRegion(std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  result.addAttribute("staticLowerBound", staticLbs);
  result.addAttribute("staticUpperBound", staticUbs);
  result.addAttribute("staticStep", staticSteps);
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicLbs.size()),
                           static_cast<int32_t>(dynamicUbs.size()),
                           static_cast<int32_t>(dynamicSteps.size()),
                           static_cast<int32_t>(outOperands.size())}));
  return success();
}

// Builder that takes loop bounds.
void ForallOp::build(
    mlir::OpBuilder &b, mlir::OperationState &result,
    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
    ArrayRef<OpFoldResult> steps, ValueRange outputs,
    std::optional<ArrayAttr> mapping,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Split each mixed static/dynamic list into a static attribute and dynamic
  // SSA operands.
  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);

  result.addOperands(dynamicLbs);
  result.addOperands(dynamicUbs);
  result.addOperands(dynamicSteps);
  result.addOperands(outputs);
  result.addTypes(TypeRange(outputs));

  result.addAttribute(getStaticLowerBoundAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticLbs));
  result.addAttribute(getStaticUpperBoundAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticUbs));
  result.addAttribute(getStaticStepAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticSteps));
  result.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
                              static_cast<int32_t>(dynamicUbs.size()),
                              static_cast<int32_t>(dynamicSteps.size()),
                              static_cast<int32_t>(outputs.size())}));
  if (mapping.has_value()) {
    result.addAttribute(ForallOp::getMappingAttrName(result.name),
                        mapping.value());
  }

  Region *bodyRegion = result.addRegion();
  OpBuilder::InsertionGuard g(b);
  b.createBlock(bodyRegion);
  Block &bodyBlock = bodyRegion->front();

  // Add block arguments for indices and outputs.
  bodyBlock.addArguments(
      SmallVector<Type>(lbs.size(), b.getIndexType()),
      SmallVector<Location>(staticLbs.size(), result.location));
  bodyBlock.addArguments(
      TypeRange(outputs),
      SmallVector<Location>(outputs.size(), result.location));

  b.setInsertionPointToStart(&bodyBlock);
  if (!bodyBuilderFn) {
    // No custom body: just add the implicit in_parallel terminator.
    ForallOp::ensureTerminator(*bodyRegion, b, result.location);
    return;
  }
  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
}

// Builder for the normalized form: zero lower bounds and unit steps.
void ForallOp::build(
    mlir::OpBuilder &b, mlir::OperationState &result,
    ArrayRef<OpFoldResult> ubs, ValueRange outputs,
    std::optional<ArrayAttr> mapping,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  unsigned numLoops = ubs.size();
  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
}

// Checks if the lbs are zeros and steps are ones.
1385 bool ForallOp::isNormalized() { 1386 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) { 1387 return llvm::all_of(results, [&](OpFoldResult ofr) { 1388 auto intValue = getConstantIntValue(ofr); 1389 return intValue.has_value() && intValue == val; 1390 }); 1391 }; 1392 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1); 1393 } 1394 1395 // The ensureTerminator method generated by SingleBlockImplicitTerminator is 1396 // unaware of the fact that our terminator also needs a region to be 1397 // well-formed. We override it here to ensure that we do the right thing. 1398 void ForallOp::ensureTerminator(Region ®ion, OpBuilder &builder, 1399 Location loc) { 1400 OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl< 1401 ForallOp>::ensureTerminator(region, builder, loc); 1402 auto terminator = 1403 llvm::dyn_cast<InParallelOp>(region.front().getTerminator()); 1404 if (terminator.getRegion().empty()) 1405 builder.createBlock(&terminator.getRegion()); 1406 } 1407 1408 InParallelOp ForallOp::getTerminator() { 1409 return cast<InParallelOp>(getBody()->getTerminator()); 1410 } 1411 1412 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { 1413 SmallVector<Operation *> storeOps; 1414 InParallelOp inParallelOp = getTerminator(); 1415 for (Operation &yieldOp : inParallelOp.getYieldingOps()) { 1416 if (auto parallelInsertSliceOp = 1417 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp); 1418 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) { 1419 storeOps.push_back(parallelInsertSliceOp); 1420 } 1421 } 1422 return storeOps; 1423 } 1424 1425 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { 1426 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())}; 1427 } 1428 1429 // Get lower bounds as OpFoldResult. 
1430 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() { 1431 Builder b(getOperation()->getContext()); 1432 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); 1433 } 1434 1435 // Get upper bounds as OpFoldResult. 1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() { 1437 Builder b(getOperation()->getContext()); 1438 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); 1439 } 1440 1441 // Get steps as OpFoldResult. 1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() { 1443 Builder b(getOperation()->getContext()); 1444 return getMixedValues(getStaticStep(), getDynamicStep(), b); 1445 } 1446 1447 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { 1448 auto tidxArg = llvm::dyn_cast<BlockArgument>(val); 1449 if (!tidxArg) 1450 return ForallOp(); 1451 assert(tidxArg.getOwner() && "unlinked block argument"); 1452 auto *containingOp = tidxArg.getOwner()->getParentOp(); 1453 return dyn_cast<ForallOp>(containingOp); 1454 } 1455 1456 namespace { 1457 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t). 
1458 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> { 1459 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 1460 1461 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 1462 PatternRewriter &rewriter) const final { 1463 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>(); 1464 if (!forallOp) 1465 return failure(); 1466 Value sharedOut = 1467 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource())) 1468 ->get(); 1469 rewriter.modifyOpInPlace( 1470 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); 1471 return success(); 1472 } 1473 }; 1474 1475 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> { 1476 public: 1477 using OpRewritePattern<ForallOp>::OpRewritePattern; 1478 1479 LogicalResult matchAndRewrite(ForallOp op, 1480 PatternRewriter &rewriter) const override { 1481 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound()); 1482 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound()); 1483 SmallVector<OpFoldResult> mixedStep(op.getMixedStep()); 1484 if (failed(foldDynamicIndexList(mixedLowerBound)) && 1485 failed(foldDynamicIndexList(mixedUpperBound)) && 1486 failed(foldDynamicIndexList(mixedStep))) 1487 return failure(); 1488 1489 rewriter.modifyOpInPlace(op, [&]() { 1490 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep; 1491 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep; 1492 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound, 1493 staticLowerBound); 1494 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound); 1495 op.setStaticLowerBound(staticLowerBound); 1496 1497 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound, 1498 staticUpperBound); 1499 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound); 1500 op.setStaticUpperBound(staticUpperBound); 1501 1502 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep); 1503 op.getDynamicStepMutable().assign(dynamicStep); 1504 
op.setStaticStep(staticStep); 1505 1506 op->setAttr(ForallOp::getOperandSegmentSizeAttr(), 1507 rewriter.getDenseI32ArrayAttr( 1508 {static_cast<int32_t>(dynamicLowerBound.size()), 1509 static_cast<int32_t>(dynamicUpperBound.size()), 1510 static_cast<int32_t>(dynamicStep.size()), 1511 static_cast<int32_t>(op.getNumResults())})); 1512 }); 1513 return success(); 1514 } 1515 }; 1516 1517 /// The following canonicalization pattern folds the iter arguments of 1518 /// scf.forall op if :- 1519 /// 1. The corresponding result has zero uses. 1520 /// 2. The iter argument is NOT being modified within the loop body. 1521 /// uses. 1522 /// 1523 /// Example of first case :- 1524 /// INPUT: 1525 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c) 1526 /// { 1527 /// ... 1528 /// <SOME USE OF %arg0> 1529 /// <SOME USE OF %arg1> 1530 /// <SOME USE OF %arg2> 1531 /// ... 1532 /// scf.forall.in_parallel { 1533 /// <STORE OP WITH DESTINATION %arg1> 1534 /// <STORE OP WITH DESTINATION %arg0> 1535 /// <STORE OP WITH DESTINATION %arg2> 1536 /// } 1537 /// } 1538 /// return %res#1 1539 /// 1540 /// OUTPUT: 1541 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b) 1542 /// { 1543 /// ... 1544 /// <SOME USE OF %a> 1545 /// <SOME USE OF %new_arg0> 1546 /// <SOME USE OF %c> 1547 /// ... 1548 /// scf.forall.in_parallel { 1549 /// <STORE OP WITH DESTINATION %new_arg0> 1550 /// } 1551 /// } 1552 /// return %res 1553 /// 1554 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the 1555 /// scf.forall is replaced by their corresponding operands. 1556 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body 1557 /// of the scf.forall besides within scf.forall.in_parallel terminator, 1558 /// this canonicalization remains valid. For more details, please refer 1559 /// to : 1560 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124 1561 /// 3. TODO(avarma): Generalize it for other store ops. 
Currently it 1562 /// handles tensor.parallel_insert_slice ops only. 1563 /// 1564 /// Example of second case :- 1565 /// INPUT: 1566 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b) 1567 /// { 1568 /// ... 1569 /// <SOME USE OF %arg0> 1570 /// <SOME USE OF %arg1> 1571 /// ... 1572 /// scf.forall.in_parallel { 1573 /// <STORE OP WITH DESTINATION %arg1> 1574 /// } 1575 /// } 1576 /// return %res#0, %res#1 1577 /// 1578 /// OUTPUT: 1579 /// %res = scf.forall ... shared_outs(%new_arg0 = %b) 1580 /// { 1581 /// ... 1582 /// <SOME USE OF %a> 1583 /// <SOME USE OF %new_arg0> 1584 /// ... 1585 /// scf.forall.in_parallel { 1586 /// <STORE OP WITH DESTINATION %new_arg0> 1587 /// } 1588 /// } 1589 /// return %a, %res 1590 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> { 1591 using OpRewritePattern<ForallOp>::OpRewritePattern; 1592 1593 LogicalResult matchAndRewrite(ForallOp forallOp, 1594 PatternRewriter &rewriter) const final { 1595 // Step 1: For a given i-th result of scf.forall, check the following :- 1596 // a. If it has any use. 1597 // b. If the corresponding iter argument is being modified within 1598 // the loop, i.e. has at least one store op with the iter arg as 1599 // its destination operand. For this we use 1600 // ForallOp::getCombiningOps(iter_arg). 1601 // 1602 // Based on the check we maintain the following :- 1603 // a. `resultToDelete` - i-th result of scf.forall that'll be 1604 // deleted. 1605 // b. `resultToReplace` - i-th result of the old scf.forall 1606 // whose uses will be replaced by the new scf.forall. 1607 // c. `newOuts` - the shared_outs' operand of the new scf.forall 1608 // corresponding to the i-th result with at least one use. 
    // Step 1: Partition the results of the scf.forall:
    //  - resultToDelete: results that are unused, or whose tied block
    //    argument is never written by a combining op in the terminator;
    //  - resultToReplace/newOuts: results (and their init operands) that must
    //    be preserved in the rebuilt op.
    SetVector<OpResult> resultToDelete;
    SmallVector<Value> resultToReplace;
    SmallVector<Value> newOuts;
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultToDelete.insert(result);
      } else {
        resultToReplace.push_back(result);
        newOuts.push_back(opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and are
    // being modified within the loop.
    if (resultToDelete.empty())
      return failure();

    // Step 2: For the i-th result, do the following :-
    //  a. Fetch the corresponding BlockArgument.
    //  b. Look for store ops (currently tensor.parallel_insert_slice)
    //     with the BlockArgument as its destination operand.
    //  c. Remove the operations fetched in b.
    for (OpResult result : resultToDelete) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    }

    // Step 3. Create a new scf.forall op with the new shared_outs' operands
    //         fetched earlier.
    auto newForallOp = rewriter.create<scf::ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
        forallOp.getMapping(),
        /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});

    // Step 4. Merge the block of the old scf.forall into the newly created
    //         scf.forall using the new set of arguments.
    Block *loopBody = forallOp.getBody();
    Block *newLoopBody = newForallOp.getBody();
    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
    // Form initial new bbArg list with just the control operands of the new
    // scf.forall op.
    SmallVector<Value> newBlockArgs =
        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
                            [](BlockArgument b) -> Value { return b; });
    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
    unsigned index = 0;
    // Take the new corresponding bbArg if the old bbArg was used as a
    // destination in the in_parallel op. For all other bbArgs, use the
    // corresponding init_arg from the old scf.forall op.
    for (OpResult result : forallOp.getResults()) {
      if (resultToDelete.count(result)) {
        // Deleted result: the merged body reads the init value directly.
        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
      } else {
        newBlockArgs.push_back(newSharedOutsArgs[index++]);
      }
    }
    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);

    // Step 5. Replace the uses of result of old scf.forall with that of the
    //         new scf.forall.
    for (auto &&[oldResult, newResult] :
         llvm::zip(resultToReplace, newForallOp->getResults()))
      rewriter.replaceAllUsesWith(oldResult, newResult);

    // Step 6. Replace the uses of those values that either have no use or are
    //         not being modified within the loop with the corresponding
    //         OpOperand.
    for (OpResult oldResult : resultToDelete)
      rewriter.replaceAllUsesWith(oldResult,
                                  forallOp.getTiedOpOperand(oldResult)->get());
    return success();
  }
};

/// Canonicalization pattern that removes scf.forall dimensions statically
/// known to execute zero or one time, folding the whole op away when every
/// dimension is such a dimension.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(lb);
      newMixedUpperBounds.push_back(ub);
      newMixedSteps.push_back(step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
                                      newMixedUpperBounds, newMixedSteps,
                                      op.getOutputs(), std::nullopt, nullptr);
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(newOp, [&]() {
        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
      });
    }
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

/// Replace all induction vars with a single trip count with their lower bound.
struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    bool changed = false;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      // Skip induction variables without uses; there is nothing to replace.
      if (iv.getUses().begin() == iv.getUses().end())
        continue;
      auto numIterations = constantTripCount(lb, ub, step);
      if (!numIterations.has_value() || numIterations.value() != 1) {
        continue;
      }
      // Single-trip dimension: the induction variable always equals the
      // lower bound.
      rewriter.replaceAllUsesWith(
          iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
      changed = true;
    }
    return success(changed);
  }
};

/// Fold `tensor.cast` ops on the shared_outs operands into the scf.forall
/// itself when doing so makes the loop result types "more" static.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  // Records the source/destination types of one folded cast, keyed by the
  // output operand index in the map below.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
                                              castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{castOp.getSource().getType(), castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = rewriter.create<ForallOp>(
        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Cast the new (more static) output bbArgs back to the types the
          // old body expects before merging it in.
          auto castBlockArgs =
              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
                nestedLoc, cast.dstType, oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
          ivsBlockArgs.append(castBlockArgs);
          rewriter.mergeBlocks(forallOp.getBody(),
                               bbArgs.front().getParentBlock(), ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destinations have to be updated to point to the output bbArgs directly.
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
      insertSliceOp.getDestMutable().assign(outputBlockArg);
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
                                                      oldTypeResult);
    }
    rewriter.replaceOp(forallOp, castResults);
    return success();
  }
};

} // namespace

void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
              ForallOpSingleOrZeroIterationDimsFolder,
              ForallOpReplaceConstantInductionVar>(context);
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForallOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body or
  // back into the operation itself. It is possible for loop not to enter the
  // body.
  regions.push_back(RegionSuccessor(&getRegion()));
  regions.push_back(RegionSuccessor());
}

//===----------------------------------------------------------------------===//
// InParallelOp
//===----------------------------------------------------------------------===//

// Build an InParallelOp with an empty entry block.
void InParallelOp::build(OpBuilder &b, OperationState &result) {
  OpBuilder::InsertionGuard g(b);
  Region *bodyRegion = result.addRegion();
  b.createBlock(bodyRegion);
}

LogicalResult InParallelOp::verify() {
  // scf.forall.in_parallel is only valid as the terminator of scf.forall.
  scf::ForallOp forallOp =
      dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
  if (!forallOp)
    return this->emitOpError("expected forall op parent");

  // TODO: InParallelOpInterface.
  for (Operation &op : getRegion().front().getOperations()) {
    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
      return this->emitOpError("expected only ")
             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
    }

    // Verify that inserts are into out block arguments.
    Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
    ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
    if (!llvm::is_contained(regionOutArgs, dest))
      return op.emitOpError("may only insert into an output block argument");
  }
  return success();
}

void InParallelOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();

  SmallVector<OpAsmParser::Argument, 8> regionOperands;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, regionOperands))
    return failure();

  // Ensure the region has an entry block even if the textual form was empty.
  if (region->empty())
    OpBuilder(builder.getContext()).createBlock(region.get());
  result.addRegion(std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

OpResult InParallelOp::getParentResult(int64_t idx) {
  return getOperation()->getParentOp()->getResult(idx);
}

// Collects the destination block arguments of all yielding ops in the region.
SmallVector<BlockArgument> InParallelOp::getDests() {
  return llvm::to_vector<4>(
      llvm::map_range(getYieldingOps(), [](Operation &op) {
        // Add new ops here as needed.
        auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
        return llvm::cast<BlockArgument>(insertSliceOp.getDest());
      }));
}

llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
  return getRegion().front().getOperations();
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk up a's enclosing scf.if ops looking for one that also contains b.
  IfOp ifOp = a->getParentOfType<IfOp>();
  while (ifOp) {
    // Check if b is inside ifOp. (We already know that a is.)
    if (ifOp->isProperAncestor(b))
      // b is contained in ifOp. a and b are in mutually exclusive branches if
      // they are in different blocks of ifOp.
      return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
             static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
    // Check next enclosing IfOp.
    ifOp = ifOp->getParentOfType<IfOp>();
  }

  // Could not find a common IfOp among a's and b's ancestors.
  return false;
}

LogicalResult
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                       IfOp::Adaptor adaptor,
                       SmallVectorImpl<Type> &inferredReturnTypes) {
  // Result types are taken from the operand types of the yield terminating
  // the "then" region; bail out on any malformed intermediate state.
  if (adaptor.getRegions().empty())
    return failure();
  Region *r = &adaptor.getThenRegion();
  if (r->empty())
    return failure();
  Block &b = r->front();
  if (b.empty())
    return failure();
  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
  if (!yieldOp)
    return failure();
  TypeRange types = yieldOp.getOperandTypes();
  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
                             types.end());
  return success();
}

void IfOp::build(OpBuilder &builder, OperationState &result,
                 TypeRange resultTypes, Value cond) {
  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
               /*addElseBlock=*/false);
}

void IfOp::build(OpBuilder &builder, OperationState &result,
                 TypeRange resultTypes, Value cond, bool addThenBlock,
                 bool addElseBlock) {
  assert((!addElseBlock || addThenBlock) &&
         "must not create else block w/o then block");
  result.addTypes(resultTypes);
  result.addOperands(cond);

  // Add regions and blocks.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  if (addThenBlock)
    builder.createBlock(thenRegion);
  Region *elseRegion = result.addRegion();
  if (addElseBlock)
    builder.createBlock(elseRegion);
}

void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 bool withElseRegion) {
  build(builder, result, TypeRange{}, cond, withElseRegion);
}

void IfOp::build(OpBuilder &builder, OperationState &result,
                 TypeRange resultTypes, Value cond, bool withElseRegion) {
  result.addTypes(resultTypes);
  result.addOperands(cond);

  // Build then region.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  // Only resultless ifs get an implicit terminator; otherwise the caller must
  // yield the values explicitly.
  if (resultTypes.empty())
    IfOp::ensureTerminator(*thenRegion, builder, result.location);

  // Build else region.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(elseRegion);
    if (resultTypes.empty())
      IfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}

void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 function_ref<void(OpBuilder &, Location)> thenBuilder,
                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
  assert(thenBuilder && "the builder callback for 'then' must be present");
  result.addOperands(cond);

  // Build then region.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  thenBuilder(builder, result.location);

  // Build else region.
  Region *elseRegion = result.addRegion();
  if (elseBuilder) {
    builder.createBlock(elseRegion);
    elseBuilder(builder, result.location);
  }

  // Infer result types from the regions built by the callbacks.
  SmallVector<Type> inferredReturnTypes;
  MLIRContext *ctx = builder.getContext();
  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
                                 /*properties=*/nullptr, result.regions,
                                 inferredReturnTypes))) {
    result.addTypes(inferredReturnTypes);
  }
}

LogicalResult IfOp::verify() {
  if (getNumResults() != 0 && getElseRegion().empty())
    return emitOpError("must have an else block if defining values");
  return success();
}

ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  Type i1Type = builder.getIntegerType(1);
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, i1Type, result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
    IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}

void IfOp::getSuccessorRegions(RegionBranchPoint point,
                               SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getThenRegion()));

  // Don't consider the else region if it is empty.
  Region *elseRegion = &this->getElseRegion();
  if (elseRegion->empty())
    regions.push_back(RegionSuccessor());
  else
    regions.push_back(RegionSuccessor(elseRegion));
}

void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  FoldAdaptor adaptor(operands, *this);
  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
  // Unknown or true condition: the then region may be entered.
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(&getThenRegion());

  // If the else region is empty, execution continues after the parent op.
  if (!boolAttr || !boolAttr.getValue()) {
    if (!getElseRegion().empty())
      regions.emplace_back(&getElseRegion());
    else
      regions.emplace_back(getResults());
  }
}

LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // Only match `xor %c, true`, i.e. a boolean negation.
  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  getConditionMutable().assign(xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}

void IfOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands,
    SmallVectorImpl<InvocationBounds> &invocationBounds) {
  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
    // If the condition is known, then one region is known to be executed once
    // and the other zero times.
    invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
    invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
  } else {
    // Non-constant condition. Each region may be executed 0 or 1 times.
    invocationBounds.assign(2, {0, 1});
  }
}

namespace {
// Pattern to remove unused IfOp results.
2233 struct RemoveUnusedResults : public OpRewritePattern<IfOp> { 2234 using OpRewritePattern<IfOp>::OpRewritePattern; 2235 2236 void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults, 2237 PatternRewriter &rewriter) const { 2238 // Move all operations to the destination block. 2239 rewriter.mergeBlocks(source, dest); 2240 // Replace the yield op by one that returns only the used values. 2241 auto yieldOp = cast<scf::YieldOp>(dest->getTerminator()); 2242 SmallVector<Value, 4> usedOperands; 2243 llvm::transform(usedResults, std::back_inserter(usedOperands), 2244 [&](OpResult result) { 2245 return yieldOp.getOperand(result.getResultNumber()); 2246 }); 2247 rewriter.modifyOpInPlace(yieldOp, 2248 [&]() { yieldOp->setOperands(usedOperands); }); 2249 } 2250 2251 LogicalResult matchAndRewrite(IfOp op, 2252 PatternRewriter &rewriter) const override { 2253 // Compute the list of used results. 2254 SmallVector<OpResult, 4> usedResults; 2255 llvm::copy_if(op.getResults(), std::back_inserter(usedResults), 2256 [](OpResult result) { return !result.use_empty(); }); 2257 2258 // Replace the operation if only a subset of its results have uses. 2259 if (usedResults.size() == op.getNumResults()) 2260 return failure(); 2261 2262 // Compute the result types of the replacement operation. 2263 SmallVector<Type, 4> newTypes; 2264 llvm::transform(usedResults, std::back_inserter(newTypes), 2265 [](OpResult result) { return result.getType(); }); 2266 2267 // Create a replacement operation with empty then and else regions. 2268 auto newOp = 2269 rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition()); 2270 rewriter.createBlock(&newOp.getThenRegion()); 2271 rewriter.createBlock(&newOp.getElseRegion()); 2272 2273 // Move the bodies and replace the terminators (note there is a then and 2274 // an else region since the operation returns results). 
2275 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter); 2276 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter); 2277 2278 // Replace the operation by the new one. 2279 SmallVector<Value, 4> repResults(op.getNumResults()); 2280 for (const auto &en : llvm::enumerate(usedResults)) 2281 repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); 2282 rewriter.replaceOp(op, repResults); 2283 return success(); 2284 } 2285 }; 2286 2287 struct RemoveStaticCondition : public OpRewritePattern<IfOp> { 2288 using OpRewritePattern<IfOp>::OpRewritePattern; 2289 2290 LogicalResult matchAndRewrite(IfOp op, 2291 PatternRewriter &rewriter) const override { 2292 BoolAttr condition; 2293 if (!matchPattern(op.getCondition(), m_Constant(&condition))) 2294 return failure(); 2295 2296 if (condition.getValue()) 2297 replaceOpWithRegion(rewriter, op, op.getThenRegion()); 2298 else if (!op.getElseRegion().empty()) 2299 replaceOpWithRegion(rewriter, op, op.getElseRegion()); 2300 else 2301 rewriter.eraseOp(op); 2302 2303 return success(); 2304 } 2305 }; 2306 2307 /// Hoist any yielded results whose operands are defined outside 2308 /// the if, to a select instruction. 
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Collect the types of yielded values that cannot be hoisted because they
    // are defined inside one of the two branches.
    SmallVector<Type> nonHoistable;
    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
                                             /*withElseRegion=*/false);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Non-hoistable values keep flowing through the (smaller) replacement if;
    // hoistable ones become arith.select ops, or are forwarded directly when
    // both branches yield the same value.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    rewriter.setInsertionPoint(replacement);
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
      else
        results[it.index()] = rewriter.create<arith::SelectOp>(
            op.getLoc(), cond, trueVal, falseVal);
    }

    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Allow the true region of an if to assume the condition is true
/// and vice versa. For example:
///
///   scf.if %cmp {
///      print(%cmp)
///   }
///
///  becomes
///
///   scf.if %cmp {
///      print(true)
///   }
///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (matchPattern(op.getCondition(), m_Constant()))
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    for (OpOperand &use :
         llvm::make_early_inc_range(op.getCondition().getUses())) {
      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
        // Use inside the then branch: the condition is known true there.
        changed = true;

        if (!constantTrue)
          constantTrue = rewriter.create<arith::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantTrue); });
      } else if (op.getElseRegion().isAncestor(
                     use.getOwner()->getParentRegion())) {
        // Use inside the else branch: the condition is known false there.
        changed = true;

        if (!constantFalse)
          constantFalse = rewriter.create<arith::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantFalse); });
      }
    }

    return success(changed);
  }
};

/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
///    %res:2 = scf.if %cmp {
///       yield something(), true
///    } else {
///       yield something2(), false
///    }
///    print(%res#1)
///
///  becomes
///    %res = scf.if %cmp {
///       yield something()
///    } else {
///       yield something2()
///    }
///    print(%cmp)
///
/// Additionally if both branches yield the same value, replace all uses
/// of the result with the yielded value.
///
///    %res:2 = scf.if %cmp {
///       yield something(), %arg1
///    } else {
///       yield something2(), %arg1
///    }
///    print(%res#1)
///
///  becomes
///    %res = scf.if %cmp {
///       yield something()
///    } else {
///       yield something2()
///    }
///    print(%arg1)
///
struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if there are no results that could be replaced.
    if (op.getNumResults() == 0)
      return failure();

    auto trueYield =
        cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
    auto falseYield =
        cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());

    rewriter.setInsertionPoint(op->getBlock(),
                               op.getOperation()->getIterator());
    bool changed = false;
    Type i1Ty = rewriter.getI1Type();
    for (auto [trueResult, falseResult, opResult] :
         llvm::zip(trueYield.getResults(), falseYield.getResults(),
                   op.getResults())) {
      // Both branches yield the same value: forward it directly.
      if (trueResult == falseResult) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(trueResult);
          changed = true;
        }
        continue;
      }

      BoolAttr trueYield, falseYield;
      if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
          !matchPattern(falseResult, m_Constant(&falseYield)))
        continue;

      bool trueVal = trueYield.getValue();
      bool falseVal = falseYield.getValue();
      // Yields (false, true): the result equals the negated condition.
      if (!trueVal && falseVal) {
        if (!opResult.use_empty()) {
          Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
          Value notCond = rewriter.create<arith::XOrIOp>(
              op.getLoc(), op.getCondition(),
              constDialect
                  ->materializeConstant(rewriter,
                                        rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
                                        op.getLoc())
                  ->getResult(0));
          opResult.replaceAllUsesWith(notCond);
          changed = true;
        }
      }
      // Yields (true, false): the result equals the condition itself.
      if (trueVal && !falseVal) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(op.getCondition());
          changed = true;
        }
      }
    }
    return success(changed);
  }
};

/// Merge any consecutive scf.if's with the same condition.
///
///    scf.if %cond {
///       firstCodeTrue();...
///    } else {
///       firstCodeFalse();...
///    }
///    %res = scf.if %cond {
///       secondCodeTrue();...
///    } else {
///       secondCodeFalse();...
///    }
///
///  becomes
///    %res = scf.if %cmp {
///       firstCodeTrue();...
///       secondCodeTrue();...
///    } else {
///       firstCodeFalse();...
///       secondCodeFalse();...
///    }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is the negation (xor true) of prevIf's condition.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // prevIf's condition is the negation (xor true) of nextIf's condition.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields.
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeOpModification(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeOpModification(use.getOwner());
        }
      }

    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Append nextIf's then-ops and concatenate the two yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Append nextIf's else-ops and concatenate the two yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back between the two original ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};

/// Pattern to remove an empty else branch.
struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Cannot remove else region when there are operation results.
    if (ifOp.getNumResults())
      return failure();
    Block *elseBlock = ifOp.elseBlock();
    // "Empty" means the else block holds only its terminator.
    if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
      return failure();
    auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
                                newIfOp.getThenRegion().begin());
    rewriter.eraseOp(ifOp);
    return success();
  }
};

/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
///  becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    Location loc = op.getLoc();
    // New condition: both the outer and the inner condition must hold.
    Value newCondition = rewriter.create<arith::AndIOp>(
        loc, op.getCondition(), nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
    Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    rewriter.setInsertionPoint(newIf);

    // Selects are created before newIf so they only depend on the outer
    // condition, per the correctness argument above.
    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);

    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      rewriter.create<YieldOp>(loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};

} // namespace

void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              RemoveStaticCondition, RemoveUnusedResults,
              ReplaceIfYieldWithConditionOrValue>(context);
}

Block
*IfOp::thenBlock() { return &getThenRegion().back(); }
YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
Block *IfOp::elseBlock() {
  // The else region may be absent (empty region); return null in that case.
  Region &r = getElseRegion();
  if (r.empty())
    return nullptr;
  return &r.back();
}
YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps, ValueRange initVals,
    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilderFn) {
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(initVals);
  // Record the sizes of the variadic operand groups so the op can tell the
  // four segments apart.
  result.addAttribute(
      ParallelOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));
  result.addTypes(initVals.getTypes());

  OpBuilder::InsertionGuard guard(builder);
  // One index-typed block argument per induction variable.
  unsigned numIVs = steps.size();
  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIVs, result.location);
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    // The callback receives the induction variables and the iteration
    // arguments (for reductions) as separate ranges.
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(numIVs),
                  bodyBlock->getArguments().drop_front(numIVs));
  }
  // Add terminator only if there are no reductions.
  if (initVals.empty())
    ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
}

void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
  // we don't capture a reference to a temporary by constructing the lambda at
  // function level.
  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
                                           Location nestedLoc, ValueRange ivs,
                                           ValueRange) {
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  };
  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
  if (bodyBuilderFn)
    wrapper = wrappedBuilderFn;

  // Delegate to the reduction-capable overload with no init values.
  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
        wrapper);
}

LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive.
  for (Value stepValue : stepValues)
    if (auto cst = getConstantIntValue(stepValue))
      if (*cst <= 0)
        return emitOpError("constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  // All induction variables must be of index type.
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          "expects arguments for the induction variable to be of index type");

  // Check that the terminator is an scf.reduce op.
  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
      *this, getRegion(), "expects body to terminate with 'scf.reduce'");
  if (!reduceOp)
    return failure();

  // Check that the number of results is the same as the number of reductions.
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;

  // Check that the types of the results and reductions are the same.
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  }
  return success();
}

ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument, 4> ivs;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse loop bounds. Lower/upper bounds and steps are all index-typed and
  // must match the induction variable count.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  // Parse init values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
      return failure();
  }

  // Parse optional results in case there is a reduce.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = builder.getIndexType();
  if (parser.parseRegion(*body, ivs))
    return failure();

  // Set `operandSegmentSizes` attribute.
  result.addAttribute(
      ParallelOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));

  // Parse attributes. Init values are resolved against the declared result
  // types, since each init value must match the corresponding result.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Add a terminator if none was parsed.
  ParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}

void ParallelOp::print(OpAsmPrinter &p) {
  // Mirrors the grammar accepted by ParallelOp::parse above.
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}

SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }

std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
  return SmallVector<Value>{getBody()->getArguments()};
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
  return getLowerBound();
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
  return getUpperBound();
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
  return getStep();
}

/// Returns the scf.parallel op that owns `val` as an induction variable, or a
/// null op if `val` is not a block argument of an scf.parallel body.
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg)
    return ParallelOp();
  assert(ivArg.getOwner() && "unlinked block argument");
  auto *containingOp = ivArg.getOwner()->getParentOp();
  return dyn_cast<ParallelOp>(containingOp);
}

namespace {
// Collapse loop dimensions that perform a single iteration.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                   op.getInductionVars())) {
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newLowerBounds.push_back(lb);
      newUpperBounds.push_back(ub);
      newSteps.push_back(step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(bodyOp, mapping);
      auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        // The i-th reduction combines the i-th init value (first block
        // argument) with the cloned i-th reduce operand (second block
        // argument).
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
      }

      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
                                    newSteps, op.getInitVals(), nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

/// Fuse a perfectly nested pair of scf.parallel loops into a single
/// multi-dimensional scf.parallel loop.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    // The inner loop must be the only op in the outer body (perfect nesting).
    Block &outerBody = *op.getBody();
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps must not depend on the outer induction
    // variables, otherwise they cannot be hoisted to the merged loop.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      // The merged loop's IVs are the outer IVs followed by the inner IVs.
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      IRMapping mapping;
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, std::nullopt,
                                            bodyBuilder);
    return success();
  }
};

} // namespace

void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results
      .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
          context);
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body or
  // back into the operation itself. It is possible for loop not to enter the
  // body.
  regions.push_back(RegionSuccessor(&getRegion()));
  regions.push_back(RegionSuccessor());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result) {}

void ReduceOp::build(OpBuilder &builder, OperationState &result,
                     ValueRange operands) {
  result.addOperands(operands);
  // One reduction region per operand, each with two block arguments of the
  // operand's type (the accumulated value and the new value).
  for (Value v : operands) {
    OpBuilder::InsertionGuard guard(builder);
    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion, {},
                        ArrayRef<Type>{v.getType(), v.getType()},
                        {result.location, result.location});
  }
}

LogicalResult ReduceOp::verifyRegions() {
  // The region of a ReduceOp has two arguments of the same type as its
  // corresponding operand.
  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
    auto type = getOperands()[i].getType();
    Block &block = getReductions()[i].front();
    if (block.empty())
      return emitOpError() << i << "-th reduction has an empty body";
    if (block.getNumArguments() != 2 ||
        llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
          return arg.getType() != type;
        }))
      return emitOpError() << "expected two block arguments with type " << type
                           << " in the " << i << "-th reduction region";

    // Check that the block is terminated by a ReduceReturnOp.
    if (!isa<ReduceReturnOp>(block.getTerminator()))
      return emitOpError("reduction bodies must be terminated with an "
                         "'scf.reduce.return' op");
  }

  return success();
}

MutableOperandRange
ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  // No operands are forwarded to the next iteration.
  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}

//===----------------------------------------------------------------------===//
// ReduceReturnOp
//===----------------------------------------------------------------------===//

LogicalResult ReduceReturnOp::verify() {
  // The type of the return value should be the same type as the types of the
  // block arguments of the reduction body.
  Block *reductionBody = getOperation()->getBlock();
  // Should already be verified by an op trait.
  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
  Type expectedResultType = reductionBody->getArgument(0).getType();
  if (expectedResultType != getResult().getType())
    return emitOpError() << "must have type " << expectedResultType
                         << " (the type of the reduction inputs)";
  return success();
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

void WhileOp::build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, TypeRange resultTypes,
                    ValueRange inits, BodyBuilderFn beforeBuilder,
                    BodyBuilderFn afterBuilder) {
  odsState.addOperands(inits);
  odsState.addTypes(resultTypes);

  OpBuilder::InsertionGuard guard(odsBuilder);

  // Build before region. Its block arguments mirror the init operands, so
  // reuse the operands' locations.
  SmallVector<Location, 4> beforeArgLocs;
  beforeArgLocs.reserve(inits.size());
  for (Value operand : inits) {
    beforeArgLocs.push_back(operand.getLoc());
  }

  Region *beforeRegion = odsState.addRegion();
  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
                                              inits.getTypes(), beforeArgLocs);
  if (beforeBuilder)
    beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());

  // Build after region. Its block arguments mirror the op's result types.
  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);

  Region *afterRegion = odsState.addRegion();
  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
                                             resultTypes, afterArgLocs);

  if (afterBuilder)
    afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}

ConditionOp WhileOp::getConditionOp() {
  return cast<ConditionOp>(getBeforeBody()->getTerminator());
}

YieldOp WhileOp::getYieldOp() {
  return cast<YieldOp>(getAfterBody()->getTerminator());
}

std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
  return getYieldOp().getResultsMutable();
}

Block::BlockArgListType WhileOp::getBeforeArguments() {
  return getBeforeBody()->getArguments();
}

Block::BlockArgListType WhileOp::getAfterArguments() {
  return getAfterBody()->getArguments();
}

Block::BlockArgListType WhileOp::getRegionIterArgs() {
  return getBeforeArguments();
}

OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBefore() &&
         "WhileOp is expected to branch only to the first region");
  return getInits();
}

void WhileOp::getSuccessorRegions(RegionBranchPoint point,
                                  SmallVectorImpl<RegionSuccessor> &regions) {
  // The parent op always branches to the condition region.
  if (point.isParent()) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
         "there are only two regions in a WhileOp");
  // The body region always branches back to the condition region.
  if (point == getAfter()) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  // From the condition region, control either exits the loop (op results) or
  // enters the body region.
  regions.emplace_back(getResults());
  regions.emplace_back(&getAfter(), getAfter().getArguments());
}

SmallVector<Region *> WhileOp::getLoopRegions() {
  return {&getBefore(), &getAfter()};
}

/// Parses a `while` op.
///
/// op ::= `scf.while` assignments `:` function-type region `do` region
///         `attributes` attribute-dict
/// initializer ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *before = result.addRegion();
  Region *after = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // The functional type supplies both the operand types (inputs) and the
  // result types (outputs).
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*before, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*after) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getAfter());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}

/// Verifies that two ranges of types match, i.e. have the same number of
/// entries and that types are pairwise equals. Reports errors on the given
/// operation in case of mismatch.
template <typename OpTy>
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
                                           TypeRange right, StringRef message) {
  if (left.size() != right.size())
    return op.emitOpError("expects the same number of ") << message;

  for (unsigned i = 0, e = left.size(); i < e; ++i) {
    if (left[i] != right[i]) {
      InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
                                << message;
      diag.attachNote() << "for argument " << i << ", found " << left[i]
                        << " and " << right[i];
      return diag;
    }
  }

  return success();
}

LogicalResult scf::WhileOp::verify() {
  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
      *this, getBefore(),
      "expects the 'before' region to terminate with 'scf.condition'");
  if (!beforeTerminator)
    return failure();

  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
      *this, getAfter(),
      "expects the 'after' region to terminate with 'scf.yield'");
  return success(afterTerminator != nullptr);
}

namespace {
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
/// scf.while (..) : (i1, ...) -> ... {
///  %condition = call @evaluate_condition() : () -> i1
///  scf.condition(%condition) %condition : i1, ...
/// } do {
/// ^bb0(%arg0: i1, ...):
///    use(%arg0)
///    ...
///
/// becomes
/// scf.while (..) : (i1, ...) -> ... {
///  %condition = call @evaluate_condition() : () -> i1
///  scf.condition(%condition) %condition : i1, ...
/// } do {
/// ^bb0(%arg0: i1, ...):
///    use(%true)
///    ...
struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;

    bool replaced = false;
    // For each condition operand that is the condition itself, the matching
    // after-block argument is known to be true whenever the body runs.
    for (auto yieldedAndBlockArgs :
         llvm::zip(term.getArgs(), op.getAfterArguments())) {
      if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
        if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                op.getLoc(), term.getCondition().getType(),
                rewriter.getBoolAttr(true));

          rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
                                      constantTrue);
          replaced = true;
        }
      }
    }
    return success(replaced);
  }
};

/// Remove loop invariant arguments from `before` block of scf.while.
/// A before block argument is considered loop invariant if :-
///   1. i-th yield operand is equal to the i-th while operand.
///   2.
i-th yield operand is k-th after block argument which is (k+1)-th 3519 /// condition operand AND this (k+1)-th condition operand is equal to i-th 3520 /// iter argument/while operand. 3521 /// For the arguments which are removed, their uses inside scf.while 3522 /// are replaced with their corresponding initial value. 3523 /// 3524 /// Eg: 3525 /// INPUT :- 3526 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, 3527 /// ..., %argN_before = %N) 3528 /// { 3529 /// ... 3530 /// scf.condition(%cond) %arg1_before, %arg0_before, 3531 /// %arg2_before, %arg0_before, ... 3532 /// } do { 3533 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, 3534 /// ..., %argK_after): 3535 /// ... 3536 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN 3537 /// } 3538 /// 3539 /// OUTPUT :- 3540 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = 3541 /// %N) 3542 /// { 3543 /// ... 3544 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ... 3545 /// } do { 3546 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, 3547 /// ..., %argK_after): 3548 /// ... 3549 /// scf.yield %arg1_after, ..., %argN 3550 /// } 3551 /// 3552 /// EXPLANATION: 3553 /// We iterate over each yield operand. 3554 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand 3555 /// %arg0_before, which in turn is the 0-th iter argument. So we 3556 /// remove 0-th before block argument and yield operand, and replace 3557 /// all uses of the 0-th before block argument with its initial value 3558 /// %a. 3559 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial 3560 /// value. So we remove this operand and the corresponding before 3561 /// block argument and replace all uses of 1-th before block argument 3562 /// with %b. 
struct RemoveLoopInvariantArgsFromBeforeBlock
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &afterBlock = *op.getAfterBody();
    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();
    Operation *yieldOp = afterBlock.getTerminator();
    ValueRange yieldOpArgs = yieldOp->getOperands();

    // First pass: cheaply detect whether any before-block argument is loop
    // invariant before mutating anything.
    bool canSimplify = false;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();
      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        canSimplify = true;
        break;
      }
      // If the i-th yield operand is k-th after block argument, then we check
      // if the (k+1)-th condition op operand is equal to either the i-th before
      // block argument or the initial value of i-th before block argument. If
      // the comparison results `true`, i-th before block argument is a loop
      // invariant.
      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
          canSimplify = true;
          break;
        }
      }
    }

    if (!canSimplify)
      return failure();

    // Second pass: partition arguments into invariant ones (recorded in
    // `beforeBlockInitValMap`, keyed by original argument index) and the
    // ones that must be kept.
    SmallVector<Value> newInitArgs, newYieldOpArgs;
    DenseMap<unsigned, Value> beforeBlockInitValMap;
    SmallVector<Location> newBeforeBlockArgLocs;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();

      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        beforeBlockInitValMap.insert({index, initVal});
        continue;
      } else {
        // If the i-th yield operand is k-th after block argument, then we check
        // if the (k+1)-th condition op operand is equal to either the i-th
        // before block argument or the initial value of i-th before block
        // argument. If the comparison results `true`, i-th before block
        // argument is a loop invariant.
        auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
            beforeBlockInitValMap.insert({index, initVal});
            continue;
          }
        }
      }
      // Not invariant: keep the init value, the yielded value, and the
      // location of the corresponding before-block argument.
      newInitArgs.emplace_back(initVal);
      newYieldOpArgs.emplace_back(yieldOpArg);
      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
    }

    // Rebuild the yield with only the surviving operands.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
    }

    auto newWhile =
        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);

    Block &newBeforeBlock = *rewriter.createBlock(
        &newWhile.getBefore(), /*insertPt*/ {},
        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);

    Block &beforeBlock = *op.getBeforeBody();
    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
    // For each i-th before block argument we find it's replacement value as :-
    //   1. If i-th before block argument is a loop invariant, we fetch it's
    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
    //   2. Else we fetch j-th new before block argument as the replacement
    //      value of i-th before block argument.
    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
      // If the index 'i' argument was a loop invariant we fetch it's initial
      // value from `beforeBlockInitValMap`.
      if (beforeBlockInitValMap.count(i) != 0)
        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
      else
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
    }

    // Move both bodies into the new op and replace the old op's results.
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
                                newWhile.getAfter().begin());

    rewriter.replaceOp(op, newWhile.getResults());
    return success();
  }
};

/// Remove loop invariant value from result (condition op) of scf.while.
/// A value is considered loop invariant if the final value yielded by
/// scf.condition is defined outside of the `before` block. We remove the
/// corresponding argument in `after` block and replace the use with the value.
/// We also replace the use of the corresponding result of scf.while with the
/// value.
///
/// Eg:
///    INPUT :-
///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
///                                             %argN_before = %N) {
///                ...
///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
///           } do {
///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
///                ...
///                some_func(%arg1_after)
///                ...
///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
///           }
///
///    OUTPUT :-
///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
///                ...
///                scf.condition(%cond) %arg0, %arg1, ..., %argM
///           } do {
///             ^bb0(%arg0, %arg3, ..., %argM):
///                ...
///                some_func(%a)
///                ...
///                scf.yield %arg0, %b, ..., %argN
///           }
///
///     EXPLANATION:
///       1. The 1-th and 2-th operand of scf.condition are defined outside the
///          before block of scf.while, so they get removed.
///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
///          replaced by %b.
///       3.
///          The corresponding after block argument %arg1_after's uses are
///          replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = *op.getBeforeBody();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

    // Quick scan: bail out unless at least one condition operand is defined
    // outside the `before` block.
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        canSimplify = true;
        break;
      }
    }

    if (!canSimplify)
      return failure();

    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();

    // Partition condition operands into invariant ones (recorded in
    // `condOpInitValMap`, keyed by original index) and the ones to keep.
    SmallVector<Value> newCondOpArgs;
    SmallVector<Type> newAfterBlockType;
    DenseMap<unsigned, Value> condOpInitValMap;
    SmallVector<Location> newAfterBlockArgLocs;
    for (const auto &it : llvm::enumerate(condOpArgs)) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        condOpInitValMap.insert({index, condOpArg});
      } else {
        newCondOpArgs.emplace_back(condOpArg);
        newAfterBlockType.emplace_back(condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
      }
    }

    // Rebuild scf.condition with only the surviving forwarded values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(condOp);
      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                               newCondOpArgs);
    }

    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
                                             op.getOperands());

    Block &newAfterBlock =
        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
                              newAfterBlockType, newAfterBlockArgLocs);

    Block &afterBlock = *op.getAfterBody();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
    // values too.
    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      // If index 'i' argument was loop invariant we fetch it's value from the
      // `condOpInitMap` map.
      if (condOpInitValMap.count(i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
      } else {
        afterBlockArg = newAfterBlock.getArgument(j);
        result = newWhile.getResult(j);
        j++;
      }
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    }

    // Move both bodies into the new op and replace the old results with the
    // mix of new results and hoisted invariant values.
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    rewriter.replaceOp(op, newWhileResults);
    return success();
  }
};

/// Remove WhileOp results that are also unused in 'after' block.
///
///  %0:2 = scf.while () : () -> (i32, i64) {
///    %condition = "test.condition"() : () -> i1
///    %v1 = "test.get_some_value"() : () -> i32
///    %v2 = "test.get_some_value"() : () -> i64
///    scf.condition(%condition) %v1, %v2 : i32, i64
///  } do {
///  ^bb0(%arg0: i32, %arg1: i64):
///    "test.use"(%arg0) : (i32) -> ()
///    scf.yield
///  }
///  return %0#0 : i32
///
/// becomes
///  %0 = scf.while () : () -> (i32) {
///    %condition = "test.condition"() : () -> i1
///    %v1 = "test.get_some_value"() : () -> i32
///    %v2 = "test.get_some_value"() : () -> i64
///    scf.condition(%condition) %v1 : i32
///  } do {
///  ^bb0(%arg0: i32):
///    "test.use"(%arg0) : (i32) -> ()
///    scf.yield
///  }
///  return %0 : i32
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();
    auto afterArgs = op.getAfterArguments();
    auto termArgs = term.getArgs();

    // Collect results mapping, new terminator args and new result types.
    SmallVector<unsigned> newResultsIndices;
    SmallVector<Type> newResultTypes;
    SmallVector<Value> newTermArgs;
    SmallVector<Location> newArgLocs;
    bool needUpdate = false;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
      auto i = static_cast<unsigned>(it.index());
      Value result = std::get<0>(it.value());
      Value afterArg = std::get<1>(it.value());
      Value termArg = std::get<2>(it.value());
      // A result is droppable only if both the op result and the matching
      // after-block argument are unused.
      if (result.use_empty() && afterArg.use_empty()) {
        needUpdate = true;
      } else {
        newResultsIndices.emplace_back(i);
        newTermArgs.emplace_back(termArg);
        newResultTypes.emplace_back(result.getType());
        newArgLocs.emplace_back(result.getLoc());
      }
    }

    if (!needUpdate)
      return failure();

    // Rebuild scf.condition with only the live forwarded values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(term);
      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
                                               newTermArgs);
    }

    auto newWhile =
        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());

    Block &newAfterBlock = *rewriter.createBlock(
        &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);

    // Build new results list and new after block args (unused entries will be
    // null).
    SmallVector<Value> newResults(op.getNumResults());
    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
    for (const auto &it : llvm::enumerate(newResultsIndices)) {
      newResults[it.value()] = newWhile.getResult(it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
    }

    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};

/// Replace operations equivalent to the condition in the do block with true,
/// since otherwise the block would not be evaluated.
///
/// scf.while (..) : (i32, ...) -> ... {
///   %z = ... : i32
///   %condition = cmpi pred %z, %a
///   scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
///   %condition2 = cmpi pred %arg0, %a
///   use(%condition2)
///   ...
///
/// becomes
/// scf.while (..) : (i32, ...) -> ... {
///   %z = ... : i32
///   %condition = cmpi pred %z, %a
///   scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
///   use(%true)
///   ...
struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    using namespace scf;
    auto cond = op.getConditionOp();
    // Only handle conditions produced directly by an integer compare.
    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
    if (!cmp)
      return failure();
    bool changed = false;
    for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
      // Try matching the forwarded value against either side of the compare.
      for (size_t opIdx = 0; opIdx < 2; opIdx++) {
        if (std::get<0>(tup) != cmp.getOperand(opIdx))
          continue;
        for (OpOperand &u :
             llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
          auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
          if (!cmp2)
            continue;
          // For a binary operator 1-opIdx gets the other side.
          if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
            continue;
          bool samePredicate;
          if (cmp2.getPredicate() == cmp.getPredicate())
            samePredicate = true;
          else if (cmp2.getPredicate() ==
                   arith::invertPredicate(cmp.getPredicate()))
            samePredicate = false;
          else
            continue;

          // Inside `after` the loop condition is known to hold, so the
          // equivalent compare folds to true (false when inverted).
          rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
                                                            1);
          changed = true;
        }
      }
    }
    return success(changed);
  }
};

/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {

    if (!llvm::any_of(op.getBeforeArguments(),
                      [](Value arg) { return arg.use_empty(); }))
      return rewriter.notifyMatchFailure(op, "No args to remove");

    YieldOp yield = op.getYieldOp();

    // Collect results mapping, new terminator args and new result types.
    SmallVector<Value> newYields;
    SmallVector<Value> newInits;
    llvm::BitVector argsToErase;

    size_t argsCount = op.getBeforeArguments().size();
    newYields.reserve(argsCount);
    newInits.reserve(argsCount);
    argsToErase.reserve(argsCount);
    // Keep only the init/yield pairs whose before-block argument has uses.
    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
             op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
      if (beforeArg.use_empty()) {
        argsToErase.push_back(true);
      } else {
        argsToErase.push_back(false);
        newYields.emplace_back(yieldValue);
        newInits.emplace_back(initValue);
      }
    }

    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    beforeBlock.eraseArguments(argsToErase);

    Location loc = op.getLoc();
    auto newWhileOp =
        rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
                                 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yield);
    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);

    // Move both bodies into the new op; the before block already had its
    // dead arguments erased above, so the argument lists line up.
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
                         newAfterBlock.getArguments());

    rewriter.replaceOp(op, newWhileOp.getResults());
    return success();
  }
};

/// Remove duplicated ConditionOp args.
struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    ConditionOp condOp = op.getConditionOp();
    ValueRange condOpArgs = condOp.getArgs();

    // If every forwarded value is distinct there is nothing to deduplicate.
    llvm::SmallPtrSet<Value, 8> argsSet;
    for (Value arg : condOpArgs)
      argsSet.insert(arg);

    if (argsSet.size() == condOpArgs.size())
      return rewriter.notifyMatchFailure(op, "No results to remove");

    // Map each distinct value to its position in the deduplicated list.
    llvm::SmallDenseMap<Value, unsigned> argsMap;
    SmallVector<Value> newArgs;
    argsMap.reserve(condOpArgs.size());
    newArgs.reserve(condOpArgs.size());
    for (Value arg : condOpArgs) {
      if (!argsMap.count(arg)) {
        auto pos = static_cast<unsigned>(argsMap.size());
        argsMap.insert({arg, pos});
        newArgs.emplace_back(arg);
      }
    }

    ValueRange argsRange(newArgs);

    Location loc = op.getLoc();
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
        /*afterBody*/ nullptr);
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    // For every original position, record which deduplicated after-block
    // argument / op result replaces it.
    SmallVector<Value> afterArgsMapping;
    SmallVector<Value> resultsMapping;
    for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
      auto it = argsMap.find(arg);
      assert(it != argsMap.end());
      auto pos = it->second;
      afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
      resultsMapping.emplace_back(newWhileOp->getResult(pos));
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(condOp);
    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                             argsRange);

    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
    rewriter.replaceOp(op, resultsMapping);
    return success();
  }
};

/// If both ranges contain same values return mapping indices from args2 to
/// args1. Otherwise return std::nullopt.
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
                                                           ValueRange args2) {
  if (args1.size() != args2.size())
    return std::nullopt;

  SmallVector<unsigned> ret(args1.size());
  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
    auto it = llvm::find(args2, arg1);
    if (it == args2.end())
      return std::nullopt;

    ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
  }

  return ret;
}

/// Returns true if `args` contains the same value more than once.
static bool hasDuplicates(ValueRange args) {
  llvm::SmallDenseSet<Value> set;
  for (Value arg : args) {
    if (!set.insert(arg).second)
      return true;
  }
  return false;
}

/// If `before` block args are directly forwarded to `scf.condition`, rearrange
/// `scf.condition` args into same order as block args. Update `after` block
/// args and op result values accordingly.
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp loop,
                                PatternRewriter &rewriter) const override {
    auto oldBefore = loop.getBeforeBody();
    ConditionOp oldTerm = loop.getConditionOp();
    ValueRange beforeArgs = oldBefore->getArguments();
    ValueRange termArgs = oldTerm.getArgs();
    // Already in block-argument order: nothing to do.
    if (beforeArgs == termArgs)
      return failure();

    // Duplicated condition args would make the permutation ambiguous.
    if (hasDuplicates(termArgs))
      return failure();

    // `mapping[k]` is the before-arg index forwarded at condition position k.
    auto mapping = getArgsMapping(beforeArgs, termArgs);
    if (!mapping)
      return failure();

    // Rebuild scf.condition forwarding the block args in their own order.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(oldTerm);
      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
                                               beforeArgs);
    }

    auto oldAfter = loop.getAfterBody();

    // Permute the result types to match the new forwarding order.
    SmallVector<Type> newResultTypes(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(*mapping))
      newResultTypes[j] = loop.getResult(i).getType();

    auto newLoop = rewriter.create<WhileOp>(
        loop.getLoc(), newResultTypes, loop.getInits(),
        /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
    auto newBefore = newLoop.getBeforeBody();
    auto newAfter = newLoop.getAfterBody();

    // Permute results and after-block arguments back to the original order.
    SmallVector<Value> newResults(beforeArgs.size());
    SmallVector<Value> newAfterArgs(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(*mapping)) {
      newResults[i] = newLoop.getResult(j);
      newAfterArgs[i] = newAfter->getArgument(j);
    }

    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
                               newBefore->getArguments());
    rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
                               newAfterArgs);

    rewriter.replaceOp(loop, newResults);
    return success();
  }
};
} // namespace

void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
              RemoveLoopInvariantValueYielded, WhileConditionTruth,
              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
}

//===----------------------------------------------------------------------===//
// IndexSwitchOp
//===----------------------------------------------------------------------===//

/// Parse the case regions and values.
static ParseResult
parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
  SmallVector<int64_t> caseValues;
  // Each case is the keyword `case`, an integer value, then a region.
  while (succeeded(p.parseOptionalKeyword("case"))) {
    int64_t value;
    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
      return failure();
    caseValues.push_back(value);
  }
  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
  return success();
}

/// Print the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
    p.printNewline();
    p << "case " << value << ' ';
    p.printRegion(*region, /*printEntryBlockArgs=*/false);
  }
}

LogicalResult scf::IndexSwitchOp::verify() {
  // One case region per case value.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError("has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // Case values must be unique.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(value).second)
      return emitOpError("has duplicate case value: ") << value;
  // Each region must terminate in an scf.yield whose operand types are
  // exactly the op's result types.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(region.front().back());
    if (!yield)
      return emitOpError("expected region to end with scf.yield, but got ")
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError("expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(yield.getLoc())
             << "see yield operation here";
    }
    for (auto [idx, result, operand] :
         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
                   yield.getOperandTypes())) {
      if (result == operand)
        continue;
      return (emitOpError("expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(yield.getLoc())
             << name << " returns " << operand << " here";
    }
    return success();
  };

  if (failed(verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}

unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }

Block &scf::IndexSwitchOp::getDefaultBlock() {
  return getDefaultRegion().front();
}

Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
  assert(idx < getNumCases() && "case index out-of-bounds");
  return getCaseRegions()[idx].front();
}

void IndexSwitchOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
  // All regions branch back to the parent op.
  if (!point.isParent()) {
    successors.emplace_back(getResults());
    return;
  }

  // From the parent, any region may be entered.
  llvm::copy(getRegions(), std::back_inserter(successors));
}

void IndexSwitchOp::getEntrySuccessorRegions(
    ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &successors) {
  FoldAdaptor adaptor(operands, *this);

  // If a constant was not provided, all regions are possible successors.
  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
  if (!arg) {
    llvm::copy(getRegions(), std::back_inserter(successors));
    return;
  }

  // Otherwise, try to find a case with a matching value. If not, the
  // default region is the only successor.
  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
    if (caseValue == arg.getInt()) {
      successors.emplace_back(&caseRegion);
      return;
    }
  }
  successors.emplace_back(&getDefaultRegion());
}

void IndexSwitchOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
  if (!operandValue) {
    // All regions are invoked at most once.
    bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
    return;
  }

  // The switch value is static: exactly one region is live — the matching
  // case, or the default region (last region) when no case matches.
  unsigned liveIndex = getNumRegions() - 1;
  const auto *it = llvm::find(getCases(), operandValue.getInt());
  if (it != getCases().end())
    liveIndex = std::distance(getCases().begin(), it);
  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
    bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
}

/// Inlines the region selected by a statically-known switch argument in place
/// of the scf.index_switch op.
struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
                                PatternRewriter &rewriter) const override {
    // If `op.getArg()` is a constant, select the region that matches with
    // the constant value. Use the default region if no match is found.
    std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
    if (!maybeCst.has_value())
      return failure();
    int64_t cst = *maybeCst;
    int64_t caseIdx, e = op.getNumCases();
    for (caseIdx = 0; caseIdx < e; ++caseIdx) {
      if (cst == op.getCases()[caseIdx])
        break;
    }

    // Falling off the end of the loop selects the default region.
    Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
                                             : op.getDefaultRegion();
    Block &source = r.front();
    Operation *terminator = source.getTerminator();
    SmallVector<Value> results = terminator->getOperands();

    rewriter.inlineBlockBefore(&source, op);
    rewriter.eraseOp(terminator);
    // Replace the operation with a potentially empty list of results.
    // Fold mechanism doesn't support the case where the result list is empty.
    rewriter.replaceOp(op, results);

    return success();
  }
};

void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldConstantCase>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"