1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" 10 11 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Affine/LoopUtils.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Arith/Utils/Utils.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 19 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 20 #include "mlir/Dialect/SCF/Utils/Utils.h" 21 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 22 #include "mlir/Dialect/Transform/IR/TransformOps.h" 23 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 24 #include "mlir/Dialect/Utils/StaticValueUtils.h" 25 #include "mlir/Dialect/Vector/IR/VectorOps.h" 26 #include "mlir/IR/BuiltinAttributes.h" 27 #include "mlir/IR/Dominance.h" 28 #include "mlir/IR/OpDefinition.h" 29 30 using namespace mlir; 31 using namespace mlir::affine; 32 33 //===----------------------------------------------------------------------===// 34 // Apply...PatternsOp 35 //===----------------------------------------------------------------------===// 36 37 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( 38 RewritePatternSet &patterns) { 39 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 40 } 41 42 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( 43 TypeConverter &typeConverter, RewritePatternSet &patterns) { 44 scf::populateSCFStructuralTypeConversions(typeConverter, patterns); 45 } 46 47 void transform::ApplySCFStructuralConversionPatternsOp:: 48 populateConversionTargetRules(const TypeConverter &typeConverter, 49 ConversionTarget &conversionTarget) { 50 scf::populateSCFStructuralTypeConversionTarget(typeConverter, 51 conversionTarget); 52 } 53 54 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( 55 TypeConverter &typeConverter, RewritePatternSet &patterns) { 56 populateSCFToControlFlowConversionPatterns(patterns); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // ForallToForOp 61 //===----------------------------------------------------------------------===// 62 63 DiagnosedSilenceableFailure 64 transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, 65 transform::TransformResults &results, 66 transform::TransformState &state) { 67 auto payload = state.getPayloadOps(getTarget()); 68 if (!llvm::hasSingleElement(payload)) 69 return emitSilenceableError() << "expected a single payload op"; 70 71 auto target = dyn_cast<scf::ForallOp>(*payload.begin()); 72 if (!target) { 73 DiagnosedSilenceableFailure diag = 74 emitSilenceableError() << "expected the payload to be scf.forall"; 75 diag.attachNote((*payload.begin())->getLoc()) << "payload op"; 76 return diag; 77 } 78 79 if (!target.getOutputs().empty()) { 80 return emitSilenceableError() 81 << "unsupported shared outputs (didn't bufferize?)"; 82 } 83 84 SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); 85 86 if (getNumResults() != lbs.size()) { 87 DiagnosedSilenceableFailure diag = 88 emitSilenceableError() 89 << "op expects as many results (" << getNumResults() 90 << ") as payload has induction variables (" << lbs.size() << ")"; 91 diag.attachNote(target.getLoc()) << "payload op"; 92 return diag; 93 } 94 95 SmallVector<Operation *> opResults; 96 if (failed(scf::forallToForLoop(rewriter, target, &opResults))) { 97 DiagnosedSilenceableFailure diag = emitSilenceableError() 98 << "failed to convert forall into for"; 99 return diag; 100 } 101 102 for (auto &&[i, res] : llvm::enumerate(opResults)) { 103 results.set(cast<OpResult>(getTransformed()[i]), {res}); 104 } 105 return DiagnosedSilenceableFailure::success(); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // ForallToForOp 110 //===----------------------------------------------------------------------===// 111 112 DiagnosedSilenceableFailure 113 transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, 114 transform::TransformResults &results, 115 transform::TransformState &state) { 116 auto payload = state.getPayloadOps(getTarget()); 117 if (!llvm::hasSingleElement(payload)) 118 return emitSilenceableError() << "expected a single payload op"; 119 120 auto target = dyn_cast<scf::ForallOp>(*payload.begin()); 121 if (!target) { 122 DiagnosedSilenceableFailure diag = 123 emitSilenceableError() << "expected the payload to be scf.forall"; 124 diag.attachNote((*payload.begin())->getLoc()) << "payload op"; 125 return diag; 126 } 127 128 if (!target.getOutputs().empty()) { 129 return emitSilenceableError() 130 << "unsupported shared outputs (didn't bufferize?)"; 131 } 132 133 if (getNumResults() != 1) { 134 DiagnosedSilenceableFailure diag = emitSilenceableError() 135 << "op expects one result, given " 136 << getNumResults(); 137 diag.attachNote(target.getLoc()) << "payload op"; 138 return diag; 139 } 140 141 scf::ParallelOp opResult; 142 if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) { 143 DiagnosedSilenceableFailure diag = 144 emitSilenceableError() << "failed to convert forall into parallel"; 145 return diag; 146 } 147 148 results.set(cast<OpResult>(getTransformed()[0]), {opResult}); 149 return DiagnosedSilenceableFailure::success(); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // LoopOutlineOp 154 //===----------------------------------------------------------------------===// 155 156 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses 157 /// the provided rewriter for all operations to remain compatible with the 158 /// rewriting infra, as opposed to just splicing the op in place. 159 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, 160 Operation *op) { 161 if (op->getNumRegions() != 1) 162 return nullptr; 163 OpBuilder::InsertionGuard g(b); 164 b.setInsertionPoint(op); 165 scf::ExecuteRegionOp executeRegionOp = 166 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); 167 { 168 OpBuilder::InsertionGuard g(b); 169 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); 170 Operation *clonedOp = b.cloneWithoutRegions(*op); 171 Region &clonedRegion = clonedOp->getRegions().front(); 172 assert(clonedRegion.empty() && "expected empty region"); 173 b.inlineRegionBefore(op->getRegions().front(), clonedRegion, 174 clonedRegion.end()); 175 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); 176 } 177 b.replaceOp(op, executeRegionOp.getResults()); 178 return executeRegionOp; 179 } 180 181 DiagnosedSilenceableFailure 182 transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, 183 transform::TransformResults &results, 184 transform::TransformState &state) { 185 SmallVector<Operation *> functions; 186 SmallVector<Operation *> calls; 187 DenseMap<Operation *, SymbolTable> symbolTables; 188 for (Operation *target : state.getPayloadOps(getTarget())) { 189 Location location = target->getLoc(); 190 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); 191 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); 192 if (!exec) { 193 DiagnosedSilenceableFailure diag = emitSilenceableError() 194 << "failed to outline"; 195 diag.attachNote(target->getLoc()) << "target op"; 196 return diag; 197 } 198 func::CallOp call; 199 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( 200 rewriter, location, exec.getRegion(), getFuncName(), &call); 201 202 if (failed(outlined)) 203 return emitDefaultDefiniteFailure(target); 204 205 if (symbolTableOp) { 206 SymbolTable &symbolTable = 207 symbolTables.try_emplace(symbolTableOp, symbolTableOp) 208 .first->getSecond(); 209 symbolTable.insert(*outlined); 210 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); 211 } 212 functions.push_back(*outlined); 213 calls.push_back(call); 214 } 215 results.set(cast<OpResult>(getFunction()), functions); 216 results.set(cast<OpResult>(getCall()), calls); 217 return DiagnosedSilenceableFailure::success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // LoopPeelOp 222 //===----------------------------------------------------------------------===// 223 224 DiagnosedSilenceableFailure 225 transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, 226 scf::ForOp target, 227 transform::ApplyToEachResultList &results, 228 transform::TransformState &state) { 229 scf::ForOp result; 230 if (getPeelFront()) { 231 LogicalResult status = 232 scf::peelForLoopFirstIteration(rewriter, target, result); 233 if (failed(status)) { 234 DiagnosedSilenceableFailure diag = 235 emitSilenceableError() << "failed to peel the first iteration"; 236 return diag; 237 } 238 } else { 239 LogicalResult status = 240 scf::peelForLoopAndSimplifyBounds(rewriter, target, result); 241 if (failed(status)) { 242 DiagnosedSilenceableFailure diag = emitSilenceableError() 243 << "failed to peel the last iteration"; 244 return diag; 245 } 246 } 247 248 results.push_back(target); 249 results.push_back(result); 250 251 return DiagnosedSilenceableFailure::success(); 252 } 253 254 //===----------------------------------------------------------------------===// 255 // LoopPipelineOp 256 //===----------------------------------------------------------------------===// 257 258 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an 259 /// operation to its logical time position given the iteration interval and the 260 /// read latency. The latter is only relevant for vector transfers. 261 static void 262 loopScheduling(scf::ForOp forOp, 263 std::vector<std::pair<Operation *, unsigned>> &schedule, 264 unsigned iterationInterval, unsigned readLatency) { 265 auto getLatency = [&](Operation *op) -> unsigned { 266 if (isa<vector::TransferReadOp>(op)) 267 return readLatency; 268 return 1; 269 }; 270 271 std::optional<int64_t> ubConstant = 272 getConstantIntValue(forOp.getUpperBound()); 273 std::optional<int64_t> lbConstant = 274 getConstantIntValue(forOp.getLowerBound()); 275 DenseMap<Operation *, unsigned> opCycles; 276 std::map<unsigned, std::vector<Operation *>> wrappedSchedule; 277 for (Operation &op : forOp.getBody()->getOperations()) { 278 if (isa<scf::YieldOp>(op)) 279 continue; 280 unsigned earlyCycle = 0; 281 for (Value operand : op.getOperands()) { 282 Operation *def = operand.getDefiningOp(); 283 if (!def) 284 continue; 285 if (ubConstant && lbConstant) { 286 unsigned ubInt = ubConstant.value(); 287 unsigned lbInt = lbConstant.value(); 288 auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); 289 earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); 290 } else { 291 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); 292 } 293 } 294 opCycles[&op] = earlyCycle; 295 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); 296 } 297 for (const auto &it : wrappedSchedule) { 298 for (Operation *op : it.second) { 299 unsigned cycle = opCycles[op]; 300 schedule.emplace_back(op, cycle / iterationInterval); 301 } 302 } 303 } 304 305 DiagnosedSilenceableFailure 306 transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, 307 scf::ForOp target, 308 transform::ApplyToEachResultList &results, 309 transform::TransformState &state) { 310 scf::PipeliningOption options; 311 options.getScheduleFn = 312 [this](scf::ForOp forOp, 313 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { 314 loopScheduling(forOp, schedule, getIterationInterval(), 315 getReadLatency()); 316 }; 317 scf::ForLoopPipeliningPattern pattern(options, target->getContext()); 318 rewriter.setInsertionPoint(target); 319 FailureOr<scf::ForOp> patternResult = 320 scf::pipelineForLoop(rewriter, target, options); 321 if (succeeded(patternResult)) { 322 results.push_back(*patternResult); 323 return DiagnosedSilenceableFailure::success(); 324 } 325 return emitDefaultSilenceableFailure(target); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // LoopPromoteIfOneIterationOp 330 //===----------------------------------------------------------------------===// 331 332 DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( 333 transform::TransformRewriter &rewriter, LoopLikeOpInterface target, 334 transform::ApplyToEachResultList &results, 335 transform::TransformState &state) { 336 (void)target.promoteIfSingleIteration(rewriter); 337 return DiagnosedSilenceableFailure::success(); 338 } 339 340 void transform::LoopPromoteIfOneIterationOp::getEffects( 341 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 342 consumesHandle(getTargetMutable(), effects); 343 modifiesPayload(effects); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // LoopUnrollOp 348 //===----------------------------------------------------------------------===// 349 350 DiagnosedSilenceableFailure 351 transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, 352 Operation *op, 353 transform::ApplyToEachResultList &results, 354 transform::TransformState &state) { 355 LogicalResult result(failure()); 356 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) 357 result = loopUnrollByFactor(scfFor, getFactor()); 358 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) 359 result = loopUnrollByFactor(affineFor, getFactor()); 360 else 361 return emitSilenceableError() 362 << "failed to unroll, incorrect type of payload"; 363 364 if (failed(result)) 365 return emitSilenceableError() << "failed to unroll"; 366 367 return DiagnosedSilenceableFailure::success(); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // LoopUnrollAndJamOp 372 //===----------------------------------------------------------------------===// 373 374 DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( 375 transform::TransformRewriter &rewriter, Operation *op, 376 transform::ApplyToEachResultList &results, 377 transform::TransformState &state) { 378 LogicalResult result(failure()); 379 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) 380 result = loopUnrollJamByFactor(scfFor, getFactor()); 381 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) 382 result = loopUnrollJamByFactor(affineFor, getFactor()); 383 else 384 return emitSilenceableError() 385 << "failed to unroll and jam, incorrect type of payload"; 386 387 if (failed(result)) 388 return emitSilenceableError() << "failed to unroll and jam"; 389 390 return DiagnosedSilenceableFailure::success(); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // LoopCoalesceOp 395 //===----------------------------------------------------------------------===// 396 397 DiagnosedSilenceableFailure 398 transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, 399 Operation *op, 400 transform::ApplyToEachResultList &results, 401 transform::TransformState &state) { 402 LogicalResult result(failure()); 403 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op)) 404 result = coalescePerfectlyNestedSCFForLoops(scfForOp); 405 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op)) 406 result = coalescePerfectlyNestedAffineLoops(affineForOp); 407 408 results.push_back(op); 409 if (failed(result)) { 410 DiagnosedSilenceableFailure diag = emitSilenceableError() 411 << "failed to coalesce"; 412 return diag; 413 } 414 return DiagnosedSilenceableFailure::success(); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // TakeAssumedBranchOp 419 //===----------------------------------------------------------------------===// 420 /// Replaces the given op with the contents of the given single-block region, 421 /// using the operands of the block terminator to replace operation results. 422 static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, 423 Region ®ion) { 424 assert(llvm::hasSingleElement(region) && "expected single-region block"); 425 Block *block = ®ion.front(); 426 Operation *terminator = block->getTerminator(); 427 ValueRange results = terminator->getOperands(); 428 rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{}); 429 rewriter.replaceOp(op, results); 430 rewriter.eraseOp(terminator); 431 } 432 433 DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( 434 transform::TransformRewriter &rewriter, scf::IfOp ifOp, 435 transform::ApplyToEachResultList &results, 436 transform::TransformState &state) { 437 rewriter.setInsertionPoint(ifOp); 438 Region ®ion = 439 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); 440 if (!llvm::hasSingleElement(region)) { 441 return emitDefiniteFailure() 442 << "requires an scf.if op with a single-block " 443 << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region"; 444 } 445 replaceOpWithRegion(rewriter, ifOp, region); 446 return DiagnosedSilenceableFailure::success(); 447 } 448 449 void transform::TakeAssumedBranchOp::getEffects( 450 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 451 onlyReadsHandle(getTargetMutable(), effects); 452 modifiesPayload(effects); 453 } 454 455 //===----------------------------------------------------------------------===// 456 // LoopFuseSiblingOp 457 //===----------------------------------------------------------------------===// 458 459 /// Check if `target` and `source` are siblings, in the context that `target` 460 /// is being fused into `source`. 461 /// 462 /// This is a simple check that just checks if both operations are in the same 463 /// block and some checks to ensure that the fused IR does not violate 464 /// dominance. 465 static DiagnosedSilenceableFailure isOpSibling(Operation *target, 466 Operation *source) { 467 // Check if both operations are same. 468 if (target == source) 469 return emitSilenceableFailure(source) 470 << "target and source need to be different loops"; 471 472 // Check if both operations are in the same block. 473 if (target->getBlock() != source->getBlock()) 474 return emitSilenceableFailure(source) 475 << "target and source are not in the same block"; 476 477 // Check if fusion will violate dominance. 478 DominanceInfo domInfo(source); 479 if (target->isBeforeInBlock(source)) { 480 // Since `target` is before `source`, all users of results of `target` 481 // need to be dominated by `source`. 482 for (Operation *user : target->getUsers()) { 483 if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { 484 return emitSilenceableFailure(target) 485 << "user of results of target should be properly dominated by " 486 "source"; 487 } 488 } 489 } else { 490 // Since `target` is after `source`, all values used by `target` need 491 // to dominate `source`. 492 493 // Check if operands of `target` are dominated by `source`. 494 for (Value operand : target->getOperands()) { 495 Operation *operandOp = operand.getDefiningOp(); 496 // Operands without defining operations are block arguments. When `target` 497 // and `source` occur in the same block, these operands dominate `source`. 498 if (!operandOp) 499 continue; 500 501 // Operand's defining operation should properly dominate `source`. 502 if (!domInfo.properlyDominates(operandOp, source, 503 /*enclosingOpOk=*/false)) 504 return emitSilenceableFailure(target) 505 << "operands of target should be properly dominated by source"; 506 } 507 508 // Check if values used by `target` are dominated by `source`. 509 bool failed = false; 510 OpOperand *failedValue = nullptr; 511 visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { 512 Operation *operandOp = operand->get().getDefiningOp(); 513 if (operandOp && !domInfo.properlyDominates(operandOp, source, 514 /*enclosingOpOk=*/false)) { 515 // `operand` is not an argument of an enclosing block and the defining 516 // op of `operand` is outside `target` but does not dominate `source`. 517 failed = true; 518 failedValue = operand; 519 } 520 }); 521 522 if (failed) 523 return emitSilenceableFailure(failedValue->getOwner()) 524 << "values used inside regions of target should be properly " 525 "dominated by source"; 526 } 527 528 return DiagnosedSilenceableFailure::success(); 529 } 530 531 /// Check if `target` scf.forall can be fused into `source` scf.forall. 532 /// 533 /// This simply checks if both loops have the same bounds, steps and mapping. 534 /// No attempt is made at checking that the side effects of `target` and 535 /// `source` are independent of each other. 536 static bool isForallWithIdenticalConfiguration(Operation *target, 537 Operation *source) { 538 auto targetOp = dyn_cast<scf::ForallOp>(target); 539 auto sourceOp = dyn_cast<scf::ForallOp>(source); 540 if (!targetOp || !sourceOp) 541 return false; 542 543 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && 544 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && 545 targetOp.getMixedStep() == sourceOp.getMixedStep() && 546 targetOp.getMapping() == sourceOp.getMapping(); 547 } 548 549 /// Check if `target` scf.for can be fused into `source` scf.for. 550 /// 551 /// This simply checks if both loops have the same bounds and steps. No attempt 552 /// is made at checking that the side effects of `target` and `source` are 553 /// independent of each other. 554 static bool isForWithIdenticalConfiguration(Operation *target, 555 Operation *source) { 556 auto targetOp = dyn_cast<scf::ForOp>(target); 557 auto sourceOp = dyn_cast<scf::ForOp>(source); 558 if (!targetOp || !sourceOp) 559 return false; 560 561 return targetOp.getLowerBound() == sourceOp.getLowerBound() && 562 targetOp.getUpperBound() == sourceOp.getUpperBound() && 563 targetOp.getStep() == sourceOp.getStep(); 564 } 565 566 DiagnosedSilenceableFailure 567 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, 568 transform::TransformResults &results, 569 transform::TransformState &state) { 570 auto targetOps = state.getPayloadOps(getTarget()); 571 auto sourceOps = state.getPayloadOps(getSource()); 572 573 if (!llvm::hasSingleElement(targetOps) || 574 !llvm::hasSingleElement(sourceOps)) { 575 return emitDefiniteFailure() 576 << "requires exactly one target handle (got " 577 << llvm::range_size(targetOps) << ") and exactly one " 578 << "source handle (got " << llvm::range_size(sourceOps) << ")"; 579 } 580 581 Operation *target = *targetOps.begin(); 582 Operation *source = *sourceOps.begin(); 583 584 // Check if the target and source are siblings. 585 DiagnosedSilenceableFailure diag = isOpSibling(target, source); 586 if (!diag.succeeded()) 587 return diag; 588 589 Operation *fusedLoop; 590 /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. 591 if (isForWithIdenticalConfiguration(target, source)) { 592 fusedLoop = fuseIndependentSiblingForLoops( 593 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); 594 } else if (isForallWithIdenticalConfiguration(target, source)) { 595 fusedLoop = fuseIndependentSiblingForallLoops( 596 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); 597 } else 598 return emitSilenceableFailure(target->getLoc()) 599 << "operations cannot be fused"; 600 601 assert(fusedLoop && "failed to fuse operations"); 602 603 results.set(cast<OpResult>(getFusedLoop()), {fusedLoop}); 604 return DiagnosedSilenceableFailure::success(); 605 } 606 607 //===----------------------------------------------------------------------===// 608 // Transform op registration 609 //===----------------------------------------------------------------------===// 610 611 namespace { 612 class SCFTransformDialectExtension 613 : public transform::TransformDialectExtension< 614 SCFTransformDialectExtension> { 615 public: 616 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) 617 618 using Base::Base; 619 620 void init() { 621 declareGeneratedDialect<affine::AffineDialect>(); 622 declareGeneratedDialect<func::FuncDialect>(); 623 624 registerTransformOps< 625 #define GET_OP_LIST 626 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 627 >(); 628 } 629 }; 630 } // namespace 631 632 #define GET_OP_CLASSES 633 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 634 635 void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { 636 registry.addExtensions<SCFTransformDialectExtension>(); 637 } 638