1 //===- TransformOps.cpp - Transform dialect operations --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Dialect/Transform/IR/TransformAttrs.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 17 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" 18 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/Dominance.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/IR/OperationSupport.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/IR/Verifier.h" 26 #include "mlir/Interfaces/CallInterfaces.h" 27 #include "mlir/Interfaces/ControlFlowInterfaces.h" 28 #include "mlir/Interfaces/FunctionImplementation.h" 29 #include "mlir/Interfaces/FunctionInterfaces.h" 30 #include "mlir/Pass/Pass.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Pass/PassRegistry.h" 33 #include "mlir/Transforms/CSE.h" 34 #include "mlir/Transforms/DialectConversion.h" 35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 36 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 37 #include "llvm/ADT/DenseSet.h" 38 #include "llvm/ADT/STLExtras.h" 39 #include "llvm/ADT/ScopeExit.h" 40 #include "llvm/ADT/SmallPtrSet.h" 41 #include "llvm/ADT/TypeSwitch.h" 42 #include "llvm/Support/Debug.h" 43 #include "llvm/Support/ErrorHandling.h" 44 #include <optional> 45 46 
#define DEBUG_TYPE "transform-dialect" 47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 48 49 #define DEBUG_TYPE_MATCHER "transform-matcher" 50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") 51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) 52 53 using namespace mlir; 54 55 static ParseResult parseSequenceOpOperands( 56 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 57 Type &rootType, 58 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 59 SmallVectorImpl<Type> &extraBindingTypes); 60 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 61 Value root, Type rootType, 62 ValueRange extraBindings, 63 TypeRange extraBindingTypes); 64 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 65 ArrayAttr matchers, ArrayAttr actions); 66 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, 67 ArrayAttr &matchers, 68 ArrayAttr &actions); 69 70 /// Helper function to check if the given transform op is contained in (or 71 /// equal to) the given payload target op. In that case, an error is returned. 72 /// Transforming transform IR that is currently executing is generally unsafe. 
73 static DiagnosedSilenceableFailure 74 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, 75 Operation *payload) { 76 Operation *transformAncestor = transform.getOperation(); 77 while (transformAncestor) { 78 if (transformAncestor == payload) { 79 DiagnosedDefiniteFailure diag = 80 transform.emitDefiniteFailure() 81 << "cannot apply transform to itself (or one of its ancestors)"; 82 diag.attachNote(payload->getLoc()) << "target payload op"; 83 return diag; 84 } 85 transformAncestor = transformAncestor->getParentOp(); 86 } 87 return DiagnosedSilenceableFailure::success(); 88 } 89 90 #define GET_OP_CLASSES 91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 92 93 //===----------------------------------------------------------------------===// 94 // AlternativesOp 95 //===----------------------------------------------------------------------===// 96 97 OperandRange 98 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { 99 if (!point.isParent() && getOperation()->getNumOperands() == 1) 100 return getOperation()->getOperands(); 101 return OperandRange(getOperation()->operand_end(), 102 getOperation()->operand_end()); 103 } 104 105 void transform::AlternativesOp::getSuccessorRegions( 106 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 107 for (Region &alternative : llvm::drop_begin( 108 getAlternatives(), 109 point.isParent() ? 0 110 : point.getRegionOrNull()->getRegionNumber() + 1)) { 111 regions.emplace_back(&alternative, !getOperands().empty() 112 ? 
alternative.getArguments() 113 : Block::BlockArgListType()); 114 } 115 if (!point.isParent()) 116 regions.emplace_back(getOperation()->getResults()); 117 } 118 119 void transform::AlternativesOp::getRegionInvocationBounds( 120 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 121 (void)operands; 122 // The region corresponding to the first alternative is always executed, the 123 // remaining may or may not be executed. 124 bounds.reserve(getNumRegions()); 125 bounds.emplace_back(1, 1); 126 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 127 } 128 129 static void forwardEmptyOperands(Block *block, transform::TransformState &state, 130 transform::TransformResults &results) { 131 for (const auto &res : block->getParentOp()->getOpResults()) 132 results.set(res, {}); 133 } 134 135 DiagnosedSilenceableFailure 136 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, 137 transform::TransformResults &results, 138 transform::TransformState &state) { 139 SmallVector<Operation *> originals; 140 if (Value scopeHandle = getScope()) 141 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 142 else 143 originals.push_back(state.getTopLevel()); 144 145 for (Operation *original : originals) { 146 if (original->isAncestor(getOperation())) { 147 auto diag = emitDefiniteFailure() 148 << "scope must not contain the transforms being applied"; 149 diag.attachNote(original->getLoc()) << "scope"; 150 return diag; 151 } 152 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 153 auto diag = emitDefiniteFailure() 154 << "only isolated-from-above ops can be alternative scopes"; 155 diag.attachNote(original->getLoc()) << "scope"; 156 return diag; 157 } 158 } 159 160 for (Region ® : getAlternatives()) { 161 // Clone the scope operations and make the transforms in this alternative 162 // region apply to them by virtue of mapping the block argument (the only 163 // visible handle) to the cloned scope operations. 
This effectively prevents 164 // the transformation from accessing any IR outside the scope. 165 auto scope = state.make_region_scope(reg); 166 auto clones = llvm::to_vector( 167 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 168 auto deleteClones = llvm::make_scope_exit([&] { 169 for (Operation *clone : clones) 170 clone->erase(); 171 }); 172 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 173 return DiagnosedSilenceableFailure::definiteFailure(); 174 175 bool failed = false; 176 for (Operation &transform : reg.front().without_terminator()) { 177 DiagnosedSilenceableFailure result = 178 state.applyTransform(cast<TransformOpInterface>(transform)); 179 if (result.isSilenceableFailure()) { 180 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 181 << "\n"); 182 failed = true; 183 break; 184 } 185 186 if (::mlir::failed(result.silence())) 187 return DiagnosedSilenceableFailure::definiteFailure(); 188 } 189 190 // If all operations in the given alternative succeeded, no need to consider 191 // the rest. Replace the original scoping operation with the clone on which 192 // the transformations were performed. 193 if (!failed) { 194 // We will be using the clones, so cancel their scheduled deletion. 
195 deleteClones.release(); 196 TrackingListener listener(state, *this); 197 IRRewriter rewriter(getContext(), &listener); 198 for (const auto &kvp : llvm::zip(originals, clones)) { 199 Operation *original = std::get<0>(kvp); 200 Operation *clone = std::get<1>(kvp); 201 original->getBlock()->getOperations().insert(original->getIterator(), 202 clone); 203 rewriter.replaceOp(original, clone->getResults()); 204 } 205 detail::forwardTerminatorOperands(®.front(), state, results); 206 return DiagnosedSilenceableFailure::success(); 207 } 208 } 209 return emitSilenceableError() << "all alternatives failed"; 210 } 211 212 void transform::AlternativesOp::getEffects( 213 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 214 consumesHandle(getOperation()->getOpOperands(), effects); 215 producesHandle(getOperation()->getOpResults(), effects); 216 for (Region *region : getRegions()) { 217 if (!region->empty()) 218 producesHandle(region->front().getArguments(), effects); 219 } 220 modifiesPayload(effects); 221 } 222 223 LogicalResult transform::AlternativesOp::verify() { 224 for (Region &alternative : getAlternatives()) { 225 Block &block = alternative.front(); 226 Operation *terminator = block.getTerminator(); 227 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 228 InFlightDiagnostic diag = emitOpError() 229 << "expects terminator operands to have the " 230 "same type as results of the operation"; 231 diag.attachNote(terminator->getLoc()) << "terminator"; 232 return diag; 233 } 234 } 235 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // AnnotateOp 241 //===----------------------------------------------------------------------===// 242 243 DiagnosedSilenceableFailure 244 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, 245 transform::TransformResults &results, 246 transform::TransformState &state) { 247 SmallVector<Operation *> targets = 248 
llvm::to_vector(state.getPayloadOps(getTarget())); 249 250 Attribute attr = UnitAttr::get(getContext()); 251 if (auto paramH = getParam()) { 252 ArrayRef<Attribute> params = state.getParams(paramH); 253 if (params.size() != 1) { 254 if (targets.size() != params.size()) { 255 return emitSilenceableError() 256 << "parameter and target have different payload lengths (" 257 << params.size() << " vs " << targets.size() << ")"; 258 } 259 for (auto &&[target, attr] : llvm::zip_equal(targets, params)) 260 target->setAttr(getName(), attr); 261 return DiagnosedSilenceableFailure::success(); 262 } 263 attr = params[0]; 264 } 265 for (auto *target : targets) 266 target->setAttr(getName(), attr); 267 return DiagnosedSilenceableFailure::success(); 268 } 269 270 void transform::AnnotateOp::getEffects( 271 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 272 onlyReadsHandle(getTargetMutable(), effects); 273 onlyReadsHandle(getParamMutable(), effects); 274 modifiesPayload(effects); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // ApplyCommonSubexpressionEliminationOp 279 //===----------------------------------------------------------------------===// 280 281 DiagnosedSilenceableFailure 282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne( 283 transform::TransformRewriter &rewriter, Operation *target, 284 ApplyToEachResultList &results, transform::TransformState &state) { 285 // Make sure that this transform is not applied to itself. Modifying the 286 // transform IR while it is being interpreted is generally dangerous. 
287 DiagnosedSilenceableFailure payloadCheck = 288 ensurePayloadIsSeparateFromTransform(*this, target); 289 if (!payloadCheck.succeeded()) 290 return payloadCheck; 291 292 DominanceInfo domInfo; 293 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); 294 return DiagnosedSilenceableFailure::success(); 295 } 296 297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects( 298 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 299 transform::onlyReadsHandle(getTargetMutable(), effects); 300 transform::modifiesPayload(effects); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // ApplyDeadCodeEliminationOp 305 //===----------------------------------------------------------------------===// 306 307 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( 308 transform::TransformRewriter &rewriter, Operation *target, 309 ApplyToEachResultList &results, transform::TransformState &state) { 310 // Make sure that this transform is not applied to itself. Modifying the 311 // transform IR while it is being interpreted is generally dangerous. 312 DiagnosedSilenceableFailure payloadCheck = 313 ensurePayloadIsSeparateFromTransform(*this, target); 314 if (!payloadCheck.succeeded()) 315 return payloadCheck; 316 317 // Maintain a worklist of potentially dead ops. 318 SetVector<Operation *> worklist; 319 320 // Helper function that adds all defining ops of used values (operands and 321 // operands of nested ops). 322 auto addDefiningOpsToWorklist = [&](Operation *op) { 323 op->walk([&](Operation *op) { 324 for (Value v : op->getOperands()) 325 if (Operation *defOp = v.getDefiningOp()) 326 if (target->isProperAncestor(defOp)) 327 worklist.insert(defOp); 328 }); 329 }; 330 331 // Helper function that erases an op. 332 auto eraseOp = [&](Operation *op) { 333 // Remove op and nested ops from the worklist. 
334 op->walk([&](Operation *op) { 335 const auto *it = llvm::find(worklist, op); 336 if (it != worklist.end()) 337 worklist.erase(it); 338 }); 339 rewriter.eraseOp(op); 340 }; 341 342 // Initial walk over the IR. 343 target->walk<WalkOrder::PostOrder>([&](Operation *op) { 344 if (op != target && isOpTriviallyDead(op)) { 345 addDefiningOpsToWorklist(op); 346 eraseOp(op); 347 } 348 }); 349 350 // Erase all ops that have become dead. 351 while (!worklist.empty()) { 352 Operation *op = worklist.pop_back_val(); 353 if (!isOpTriviallyDead(op)) 354 continue; 355 addDefiningOpsToWorklist(op); 356 eraseOp(op); 357 } 358 359 return DiagnosedSilenceableFailure::success(); 360 } 361 362 void transform::ApplyDeadCodeEliminationOp::getEffects( 363 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 364 transform::onlyReadsHandle(getTargetMutable(), effects); 365 transform::modifiesPayload(effects); 366 } 367 368 //===----------------------------------------------------------------------===// 369 // ApplyPatternsOp 370 //===----------------------------------------------------------------------===// 371 372 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( 373 transform::TransformRewriter &rewriter, Operation *target, 374 ApplyToEachResultList &results, transform::TransformState &state) { 375 // Make sure that this transform is not applied to itself. Modifying the 376 // transform IR while it is being interpreted is generally dangerous. Even 377 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver 378 // performs many additional simplifications such as dead code elimination. 379 DiagnosedSilenceableFailure payloadCheck = 380 ensurePayloadIsSeparateFromTransform(*this, target); 381 if (!payloadCheck.succeeded()) 382 return payloadCheck; 383 384 // Gather all specified patterns. 
385 MLIRContext *ctx = target->getContext(); 386 RewritePatternSet patterns(ctx); 387 if (!getRegion().empty()) { 388 for (Operation &op : getRegion().front()) { 389 cast<transform::PatternDescriptorOpInterface>(&op) 390 .populatePatternsWithState(patterns, state); 391 } 392 } 393 394 // Configure the GreedyPatternRewriteDriver. 395 GreedyRewriteConfig config; 396 config.listener = 397 static_cast<RewriterBase::Listener *>(rewriter.getListener()); 398 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 399 400 config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1) 401 ? GreedyRewriteConfig::kNoLimit 402 : getMaxIterations(); 403 config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1) 404 ? GreedyRewriteConfig::kNoLimit 405 : getMaxNumRewrites(); 406 407 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE 408 // was requested, apply the greedy pattern rewrite only once. (The greedy 409 // pattern rewrite driver already iterates to a fixpoint internally.) 410 bool cseChanged = false; 411 // One or two iterations should be sufficient. Stop iterating after a certain 412 // threshold to make debugging easier. 413 static const int64_t kNumMaxIterations = 50; 414 int64_t iteration = 0; 415 do { 416 LogicalResult result = failure(); 417 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 418 // Op is isolated from above. Apply patterns and also perform region 419 // simplification. 420 result = applyPatternsGreedily(target, frozenPatterns, config); 421 } else { 422 // Manually gather list of ops because the other 423 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated 424 // from above. This way, patterns can be applied to ops that are not 425 // isolated from above. Regions are not being simplified. Furthermore, 426 // only a single greedy rewrite iteration is performed. 
427 SmallVector<Operation *> ops; 428 target->walk([&](Operation *nestedOp) { 429 if (target != nestedOp) 430 ops.push_back(nestedOp); 431 }); 432 result = applyOpPatternsGreedily(ops, frozenPatterns, config); 433 } 434 435 // A failure typically indicates that the pattern application did not 436 // converge. 437 if (failed(result)) { 438 return emitSilenceableFailure(target) 439 << "greedy pattern application failed"; 440 } 441 442 if (getApplyCse()) { 443 DominanceInfo domInfo; 444 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, 445 &cseChanged); 446 } 447 } while (cseChanged && ++iteration < kNumMaxIterations); 448 449 if (iteration == kNumMaxIterations) 450 return emitDefiniteFailure() << "fixpoint iteration did not converge"; 451 452 return DiagnosedSilenceableFailure::success(); 453 } 454 455 LogicalResult transform::ApplyPatternsOp::verify() { 456 if (!getRegion().empty()) { 457 for (Operation &op : getRegion().front()) { 458 if (!isa<transform::PatternDescriptorOpInterface>(&op)) { 459 InFlightDiagnostic diag = emitOpError() 460 << "expected children ops to implement " 461 "PatternDescriptorOpInterface"; 462 diag.attachNote(op.getLoc()) << "op without interface"; 463 return diag; 464 } 465 } 466 } 467 return success(); 468 } 469 470 void transform::ApplyPatternsOp::getEffects( 471 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 472 transform::onlyReadsHandle(getTargetMutable(), effects); 473 transform::modifiesPayload(effects); 474 } 475 476 void transform::ApplyPatternsOp::build( 477 OpBuilder &builder, OperationState &result, Value target, 478 function_ref<void(OpBuilder &, Location)> bodyBuilder) { 479 result.addOperands(target); 480 481 OpBuilder::InsertionGuard g(builder); 482 Region *region = result.addRegion(); 483 builder.createBlock(region); 484 if (bodyBuilder) 485 bodyBuilder(builder, result.location); 486 } 487 488 //===----------------------------------------------------------------------===// 489 // 
ApplyCanonicalizationPatternsOp 490 //===----------------------------------------------------------------------===// 491 492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns( 493 RewritePatternSet &patterns) { 494 MLIRContext *ctx = patterns.getContext(); 495 for (Dialect *dialect : ctx->getLoadedDialects()) 496 dialect->getCanonicalizationPatterns(patterns); 497 for (RegisteredOperationName op : ctx->getRegisteredOperations()) 498 op.getCanonicalizationPatterns(patterns, ctx); 499 } 500 501 //===----------------------------------------------------------------------===// 502 // ApplyConversionPatternsOp 503 //===----------------------------------------------------------------------===// 504 505 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( 506 transform::TransformRewriter &rewriter, 507 transform::TransformResults &results, transform::TransformState &state) { 508 MLIRContext *ctx = getContext(); 509 510 // Instantiate the default type converter if a type converter builder is 511 // specified. 512 std::unique_ptr<TypeConverter> defaultTypeConverter; 513 transform::TypeConverterBuilderOpInterface typeConverterBuilder = 514 getDefaultTypeConverter(); 515 if (typeConverterBuilder) 516 defaultTypeConverter = typeConverterBuilder.getTypeConverter(); 517 518 // Configure conversion target. 
519 ConversionTarget conversionTarget(*getContext()); 520 if (getLegalOps()) 521 for (Attribute attr : cast<ArrayAttr>(*getLegalOps())) 522 conversionTarget.addLegalOp( 523 OperationName(cast<StringAttr>(attr).getValue(), ctx)); 524 if (getIllegalOps()) 525 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps())) 526 conversionTarget.addIllegalOp( 527 OperationName(cast<StringAttr>(attr).getValue(), ctx)); 528 if (getLegalDialects()) 529 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects())) 530 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue()); 531 if (getIllegalDialects()) 532 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects())) 533 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue()); 534 535 // Gather all specified patterns. 536 RewritePatternSet patterns(ctx); 537 // Need to keep the converters alive until after pattern application because 538 // the patterns take a reference to an object that would otherwise get out of 539 // scope. 540 SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters; 541 if (!getPatterns().empty()) { 542 for (Operation &op : getPatterns().front()) { 543 auto descriptor = 544 cast<transform::ConversionPatternDescriptorOpInterface>(&op); 545 546 // Check if this pattern set specifies a type converter. 547 std::unique_ptr<TypeConverter> typeConverter = 548 descriptor.getTypeConverter(); 549 TypeConverter *converter = nullptr; 550 if (typeConverter) { 551 keepAliveConverters.emplace_back(std::move(typeConverter)); 552 converter = keepAliveConverters.back().get(); 553 } else { 554 // No type converter specified: Use the default type converter. 
555 if (!defaultTypeConverter) { 556 auto diag = emitDefiniteFailure() 557 << "pattern descriptor does not specify type " 558 "converter and apply_conversion_patterns op has " 559 "no default type converter"; 560 diag.attachNote(op.getLoc()) << "pattern descriptor op"; 561 return diag; 562 } 563 converter = defaultTypeConverter.get(); 564 } 565 566 // Add descriptor-specific updates to the conversion target, which may 567 // depend on the final type converter. In structural converters, the 568 // legality of types dictates the dynamic legality of an operation. 569 descriptor.populateConversionTargetRules(*converter, conversionTarget); 570 571 descriptor.populatePatterns(*converter, patterns); 572 } 573 } 574 575 // Attach a tracking listener if handles should be preserved. We configure the 576 // listener to allow op replacements with different names, as conversion 577 // patterns typically replace ops with replacement ops that have a different 578 // name. 579 TrackingListenerConfig trackingConfig; 580 trackingConfig.requireMatchingReplacementOpName = false; 581 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); 582 ConversionConfig conversionConfig; 583 if (getPreserveHandles()) 584 conversionConfig.listener = &trackingListener; 585 586 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 587 for (Operation *target : state.getPayloadOps(getTarget())) { 588 // Make sure that this transform is not applied to itself. Modifying the 589 // transform IR while it is being interpreted is generally dangerous. 
590 DiagnosedSilenceableFailure payloadCheck = 591 ensurePayloadIsSeparateFromTransform(*this, target); 592 if (!payloadCheck.succeeded()) 593 return payloadCheck; 594 595 LogicalResult status = failure(); 596 if (getPartialConversion()) { 597 status = applyPartialConversion(target, conversionTarget, frozenPatterns, 598 conversionConfig); 599 } else { 600 status = applyFullConversion(target, conversionTarget, frozenPatterns, 601 conversionConfig); 602 } 603 604 // Check dialect conversion state. 605 DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); 606 if (failed(status)) { 607 diag = emitSilenceableError() << "dialect conversion failed"; 608 diag.attachNote(target->getLoc()) << "target op"; 609 } 610 611 // Check tracking listener error state. 612 DiagnosedSilenceableFailure trackingFailure = 613 trackingListener.checkAndResetError(); 614 if (!trackingFailure.succeeded()) { 615 if (diag.succeeded()) { 616 // Tracking failure is the only failure. 617 return trackingFailure; 618 } else { 619 diag.attachNote() << "tracking listener also failed: " 620 << trackingFailure.getMessage(); 621 (void)trackingFailure.silence(); 622 } 623 } 624 625 if (!diag.succeeded()) 626 return diag; 627 } 628 629 return DiagnosedSilenceableFailure::success(); 630 } 631 632 LogicalResult transform::ApplyConversionPatternsOp::verify() { 633 if (getNumRegions() != 1 && getNumRegions() != 2) 634 return emitOpError() << "expected 1 or 2 regions"; 635 if (!getPatterns().empty()) { 636 for (Operation &op : getPatterns().front()) { 637 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) { 638 InFlightDiagnostic diag = 639 emitOpError() << "expected pattern children ops to implement " 640 "ConversionPatternDescriptorOpInterface"; 641 diag.attachNote(op.getLoc()) << "op without interface"; 642 return diag; 643 } 644 } 645 } 646 if (getNumRegions() == 2) { 647 Region &typeConverterRegion = getRegion(1); 648 if 
(!llvm::hasSingleElement(typeConverterRegion.front())) 649 return emitOpError() 650 << "expected exactly one op in default type converter region"; 651 Operation *maybeTypeConverter = &typeConverterRegion.front().front(); 652 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>( 653 maybeTypeConverter); 654 if (!typeConverterOp) { 655 InFlightDiagnostic diag = emitOpError() 656 << "expected default converter child op to " 657 "implement TypeConverterBuilderOpInterface"; 658 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface"; 659 return diag; 660 } 661 // Check default type converter type. 662 if (!getPatterns().empty()) { 663 for (Operation &op : getPatterns().front()) { 664 auto descriptor = 665 cast<transform::ConversionPatternDescriptorOpInterface>(&op); 666 if (failed(descriptor.verifyTypeConverter(typeConverterOp))) 667 return failure(); 668 } 669 } 670 } 671 return success(); 672 } 673 674 void transform::ApplyConversionPatternsOp::getEffects( 675 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 676 if (!getPreserveHandles()) { 677 transform::consumesHandle(getTargetMutable(), effects); 678 } else { 679 transform::onlyReadsHandle(getTargetMutable(), effects); 680 } 681 transform::modifiesPayload(effects); 682 } 683 684 void transform::ApplyConversionPatternsOp::build( 685 OpBuilder &builder, OperationState &result, Value target, 686 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder, 687 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) { 688 result.addOperands(target); 689 690 { 691 OpBuilder::InsertionGuard g(builder); 692 Region *region1 = result.addRegion(); 693 builder.createBlock(region1); 694 if (patternsBodyBuilder) 695 patternsBodyBuilder(builder, result.location); 696 } 697 { 698 OpBuilder::InsertionGuard g(builder); 699 Region *region2 = result.addRegion(); 700 builder.createBlock(region2); 701 if (typeConverterBodyBuilder) 702 typeConverterBodyBuilder(builder, 
result.location); 703 } 704 } 705 706 //===----------------------------------------------------------------------===// 707 // ApplyToLLVMConversionPatternsOp 708 //===----------------------------------------------------------------------===// 709 710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( 711 TypeConverter &typeConverter, RewritePatternSet &patterns) { 712 Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 713 assert(dialect && "expected that dialect is loaded"); 714 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect); 715 // ConversionTarget is currently ignored because the enclosing 716 // apply_conversion_patterns op sets up its own ConversionTarget. 717 ConversionTarget target(*getContext()); 718 iface->populateConvertToLLVMConversionPatterns( 719 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns); 720 } 721 722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( 723 transform::TypeConverterBuilderOpInterface builder) { 724 if (builder.getTypeConverterType() != "LLVMTypeConverter") 725 return emitOpError("expected LLVMTypeConverter"); 726 return success(); 727 } 728 729 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { 730 Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 731 if (!dialect) 732 return emitOpError("unknown dialect or dialect not loaded: ") 733 << getDialectName(); 734 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 735 if (!iface) 736 return emitOpError( 737 "dialect does not implement ConvertToLLVMPatternInterface or " 738 "extension was not loaded: ") 739 << getDialectName(); 740 return success(); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // ApplyLoopInvariantCodeMotionOp 745 //===----------------------------------------------------------------------===// 746 747 DiagnosedSilenceableFailure 748 
transform::ApplyLoopInvariantCodeMotionOp::applyToOne( 749 transform::TransformRewriter &rewriter, LoopLikeOpInterface target, 750 transform::ApplyToEachResultList &results, 751 transform::TransformState &state) { 752 // Currently, LICM does not remove operations, so we don't need tracking. 753 // If this ever changes, add a LICM entry point that takes a rewriter. 754 moveLoopInvariantCode(target); 755 return DiagnosedSilenceableFailure::success(); 756 } 757 758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects( 759 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 760 transform::onlyReadsHandle(getTargetMutable(), effects); 761 transform::modifiesPayload(effects); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // ApplyRegisteredPassOp 766 //===----------------------------------------------------------------------===// 767 768 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( 769 transform::TransformRewriter &rewriter, Operation *target, 770 ApplyToEachResultList &results, transform::TransformState &state) { 771 // Make sure that this transform is not applied to itself. Modifying the 772 // transform IR while it is being interpreted is generally dangerous. Even 773 // more so when applying passes because they may perform a wide range of IR 774 // modifications. 775 DiagnosedSilenceableFailure payloadCheck = 776 ensurePayloadIsSeparateFromTransform(*this, target); 777 if (!payloadCheck.succeeded()) 778 return payloadCheck; 779 780 // Get pass or pass pipeline from registry. 781 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); 782 if (!info) 783 info = PassInfo::lookup(getPassName()); 784 if (!info) 785 return emitDefiniteFailure() 786 << "unknown pass or pass pipeline: " << getPassName(); 787 788 // Create pass manager and run the pass or pass pipeline. 
789 PassManager pm(getContext()); 790 if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { 791 emitError(msg); 792 return failure(); 793 }))) { 794 return emitDefiniteFailure() 795 << "failed to add pass or pass pipeline to pipeline: " 796 << getPassName(); 797 } 798 if (failed(pm.run(target))) { 799 auto diag = emitSilenceableError() << "pass pipeline failed"; 800 diag.attachNote(target->getLoc()) << "target op"; 801 return diag; 802 } 803 804 results.push_back(target); 805 return DiagnosedSilenceableFailure::success(); 806 } 807 808 //===----------------------------------------------------------------------===// 809 // CastOp 810 //===----------------------------------------------------------------------===// 811 812 DiagnosedSilenceableFailure 813 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, 814 Operation *target, ApplyToEachResultList &results, 815 transform::TransformState &state) { 816 results.push_back(target); 817 return DiagnosedSilenceableFailure::success(); 818 } 819 820 void transform::CastOp::getEffects( 821 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 822 onlyReadsPayload(effects); 823 onlyReadsHandle(getInputMutable(), effects); 824 producesHandle(getOperation()->getOpResults(), effects); 825 } 826 827 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 828 assert(inputs.size() == 1 && "expected one input"); 829 assert(outputs.size() == 1 && "expected one output"); 830 return llvm::all_of( 831 std::initializer_list<Type>{inputs.front(), outputs.front()}, 832 llvm::IsaPred<transform::TransformHandleTypeInterface>); 833 } 834 835 //===----------------------------------------------------------------------===// 836 // CollectMatchingOp 837 //===----------------------------------------------------------------------===// 838 839 /// Applies matcher operations from the given `block` using 840 /// `blockArgumentMapping` to initialize block arguments. 
Updates `state` 841 /// accordingly. If any of the matcher produces a silenceable failure, discards 842 /// it (printing the content to the debug output stream) and returns failure. If 843 /// any of the matchers produces a definite failure, reports it and returns 844 /// failure. If all matchers in the block succeed, populates `mappings` with the 845 /// payload entities associated with the block terminator operands. Note that 846 /// `mappings` will be cleared before that. 847 static DiagnosedSilenceableFailure 848 matchBlock(Block &block, 849 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping, 850 transform::TransformState &state, 851 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { 852 assert(block.getParent() && "cannot match using a detached block"); 853 auto matchScope = state.make_region_scope(*block.getParent()); 854 if (failed( 855 state.mapBlockArguments(block.getArguments(), blockArgumentMapping))) 856 return DiagnosedSilenceableFailure::definiteFailure(); 857 858 for (Operation &match : block.without_terminator()) { 859 if (!isa<transform::MatchOpInterface>(match)) { 860 return emitDefiniteFailure(match.getLoc()) 861 << "expected operations in the match part to " 862 "implement MatchOpInterface"; 863 } 864 DiagnosedSilenceableFailure diag = 865 state.applyTransform(cast<transform::TransformOpInterface>(match)); 866 if (diag.succeeded()) 867 continue; 868 869 return diag; 870 } 871 872 // Remember the values mapped to the terminator operands so we can 873 // forward them to the action. 874 ValueRange yieldedValues = block.getTerminator()->getOperands(); 875 // Our contract with the caller is that the mappings will contain only the 876 // newly mapped values, clear the rest. 
877 mappings.clear(); 878 transform::detail::prepareValueMappings(mappings, yieldedValues, state); 879 return DiagnosedSilenceableFailure::success(); 880 } 881 882 /// Returns `true` if both types implement one of the interfaces provided as 883 /// template parameters. 884 template <typename... Tys> 885 static bool implementSameInterface(Type t1, Type t2) { 886 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); 887 } 888 889 /// Returns `true` if both types implement one of the transform dialect 890 /// interfaces. 891 static bool implementSameTransformInterface(Type t1, Type t2) { 892 return implementSameInterface<transform::TransformHandleTypeInterface, 893 transform::TransformParamTypeInterface, 894 transform::TransformValueHandleTypeInterface>( 895 t1, t2); 896 } 897 898 //===----------------------------------------------------------------------===// 899 // CollectMatchingOp 900 //===----------------------------------------------------------------------===// 901 902 DiagnosedSilenceableFailure 903 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, 904 transform::TransformResults &results, 905 transform::TransformState &state) { 906 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( 907 getOperation(), getMatcher()); 908 if (matcher.isExternal()) { 909 return emitDefiniteFailure() 910 << "unresolved external symbol " << getMatcher(); 911 } 912 913 SmallVector<SmallVector<MappedValue>, 2> rawResults; 914 rawResults.resize(getOperation()->getNumResults()); 915 std::optional<DiagnosedSilenceableFailure> maybeFailure; 916 for (Operation *root : state.getPayloadOps(getRoot())) { 917 WalkResult walkResult = root->walk([&](Operation *op) { 918 DEBUG_MATCHER({ 919 DBGS_MATCHER() << "matching "; 920 op->print(llvm::dbgs(), 921 OpPrintingFlags().assumeVerified().skipRegions()); 922 llvm::dbgs() << " @" << op << "\n"; 923 }); 924 925 // Try matching. 
926 SmallVector<SmallVector<MappedValue>> mappings; 927 SmallVector<transform::MappedValue> inputMapping({op}); 928 DiagnosedSilenceableFailure diag = matchBlock( 929 matcher.getFunctionBody().front(), 930 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state, 931 mappings); 932 if (diag.isDefiniteFailure()) 933 return WalkResult::interrupt(); 934 if (diag.isSilenceableFailure()) { 935 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() 936 << " failed: " << diag.getMessage()); 937 return WalkResult::advance(); 938 } 939 940 // If succeeded, collect results. 941 for (auto &&[i, mapping] : llvm::enumerate(mappings)) { 942 if (mapping.size() != 1) { 943 maybeFailure.emplace(emitSilenceableError() 944 << "result #" << i << ", associated with " 945 << mapping.size() 946 << " payload objects, expected 1"); 947 return WalkResult::interrupt(); 948 } 949 rawResults[i].push_back(mapping[0]); 950 } 951 return WalkResult::advance(); 952 }); 953 if (walkResult.wasInterrupted()) 954 return std::move(*maybeFailure); 955 assert(!maybeFailure && "failure set but the walk was not interrupted"); 956 957 for (auto &&[opResult, rawResult] : 958 llvm::zip_equal(getOperation()->getResults(), rawResults)) { 959 results.setMappedValues(opResult, rawResult); 960 } 961 } 962 return DiagnosedSilenceableFailure::success(); 963 } 964 965 void transform::CollectMatchingOp::getEffects( 966 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 967 onlyReadsHandle(getRootMutable(), effects); 968 producesHandle(getOperation()->getOpResults(), effects); 969 onlyReadsPayload(effects); 970 } 971 972 LogicalResult transform::CollectMatchingOp::verifySymbolUses( 973 SymbolTableCollection &symbolTable) { 974 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 975 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); 976 if (!matcherSymbol || 977 !isa<TransformOpInterface>(matcherSymbol.getOperation())) 978 return emitError() << "unresolved matcher 
symbol " << getMatcher(); 979 980 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); 981 if (argumentTypes.size() != 1 || 982 !isa<TransformHandleTypeInterface>(argumentTypes[0])) { 983 return emitError() 984 << "expected the matcher to take one operation handle argument"; 985 } 986 if (!matcherSymbol.getArgAttr( 987 0, transform::TransformDialect::kArgReadOnlyAttrName)) { 988 return emitError() << "expected the matcher argument to be marked readonly"; 989 } 990 991 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); 992 if (resultTypes.size() != getOperation()->getNumResults()) { 993 return emitError() 994 << "expected the matcher to yield as many values as op has results (" 995 << getOperation()->getNumResults() << "), got " 996 << resultTypes.size(); 997 } 998 999 for (auto &&[i, matcherType, resultType] : 1000 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { 1001 if (implementSameTransformInterface(matcherType, resultType)) 1002 continue; 1003 1004 return emitError() 1005 << "mismatching type interfaces for matcher result and op result #" 1006 << i; 1007 } 1008 1009 return success(); 1010 } 1011 1012 //===----------------------------------------------------------------------===// 1013 // ForeachMatchOp 1014 //===----------------------------------------------------------------------===// 1015 1016 // This is fine because nothing is actually consumed by this op. 
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }

/// Walks the payload ops associated with the root handle. For each visited op,
/// tries the matcher/action pairs in order. The first matcher that succeeds has
/// its yielded values mapped to the action's block arguments, and the action
/// body is then applied. Values yielded by actions are appended to the
/// forwarded outputs. Silenceable failures of actions are accumulated into a
/// single "actions failed" diagnostic that is returned at the end. Any
/// definite failure interrupts the walk and is reported as a definite failure.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve all matcher/action symbols up front so that the walk below only
  // deals with function-like ops.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Accumulates the silenceable failures of actions; stays "success" if none.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();

  // Matcher inputs: the currently visited op followed by the forwarded inputs.
  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  // Values yielded by the last successful matcher (cleared by `matchBlock`).
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  // Values yielded by all applied actions, one list per forwarded output.
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  // Refers to the first entry; `matchInputMapping` is not resized afterwards,
  // so this reference stays valid during the walk.
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        // Map the matcher's yielded values to the action's block arguments
        // for as long as `scope` is alive.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          if (result.isSilenceableFailure()) {
            // Record the failure in the combined diagnostic, silence the
            // individual one and keep applying the rest of the action.
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first successful match/action pair applies to this op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}

/// Names the first result "updated_root" and every forwarded output "yielded".
void transform::ForeachMatchOp::getAsmResultNames(
    OpAsmSetValueNameFn setNameFn) {
  setNameFn(getUpdated(), "updated_root");
  for (Value v : getForwardedOutputs()) {
    setNameFn(v, "yielded");
  }
}

/// Consumes the root handle, reads the forwarded inputs, produces all results
/// and modifies the payload.
void transform::ForeachMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Bail if invalid.
  if (getOperation()->getNumOperands() < 1 ||
      getOperation()->getNumResults() < 1) {
    return modifiesPayload(effects);
  }

  consumesHandle(getRootMutable(), effects);
  onlyReadsHandle(getForwardedInputsMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

/// Parses the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`.
/// Fills `matchers` and `actions` with co-indexed flat symbol references;
/// at least one pair is required.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions) {
  StringAttr matcher;
  StringAttr action;
  SmallVector<Attribute> matcherList;
  SmallVector<Attribute> actionList;
  do {
    if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
        parser.parseSymbolName(action)) {
      return failure();
    }
    matcherList.push_back(SymbolRefAttr::get(matcher));
    actionList.push_back(SymbolRefAttr::get(action));
  } while (parser.parseOptionalComma().succeeded());

  matchers = parser.getBuilder().getArrayAttr(matcherList);
  actions = parser.getBuilder().getArrayAttr(actionList);
  return success();
}

/// Prints the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`.
1193 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 1194 ArrayAttr matchers, ArrayAttr actions) { 1195 printer.increaseIndent(); 1196 printer.increaseIndent(); 1197 for (auto &&[matcher, action, idx] : llvm::zip_equal( 1198 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) { 1199 printer.printNewline(); 1200 printer << cast<SymbolRefAttr>(matcher) << " -> " 1201 << cast<SymbolRefAttr>(action); 1202 if (idx != matchers.size() - 1) 1203 printer << ", "; 1204 } 1205 printer.decreaseIndent(); 1206 printer.decreaseIndent(); 1207 } 1208 1209 LogicalResult transform::ForeachMatchOp::verify() { 1210 if (getMatchers().size() != getActions().size()) 1211 return emitOpError() << "expected the same number of matchers and actions"; 1212 if (getMatchers().empty()) 1213 return emitOpError() << "expected at least one match/action pair"; 1214 1215 llvm::SmallPtrSet<Attribute, 8> matcherNames; 1216 for (Attribute name : getMatchers()) { 1217 if (matcherNames.insert(name).second) 1218 continue; 1219 emitWarning() << "matcher " << name 1220 << " is used more than once, only the first match will apply"; 1221 } 1222 1223 return success(); 1224 } 1225 1226 /// Checks that the attributes of the function-like operation have correct 1227 /// consumption effect annotations. If `alsoVerifyInternal`, checks for 1228 /// annotations being present even if they can be inferred from the body. 
1229 static DiagnosedSilenceableFailure 1230 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, 1231 bool alsoVerifyInternal = false) { 1232 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation()); 1233 llvm::SmallDenseSet<unsigned> consumedArguments; 1234 if (!op.isExternal()) { 1235 transform::getConsumedBlockArguments(op.getFunctionBody().front(), 1236 consumedArguments); 1237 } 1238 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { 1239 bool isConsumed = 1240 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != 1241 nullptr; 1242 bool isReadOnly = 1243 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != 1244 nullptr; 1245 if (isConsumed && isReadOnly) { 1246 return transformOp.emitSilenceableError() 1247 << "argument #" << i << " cannot be both readonly and consumed"; 1248 } 1249 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { 1250 return transformOp.emitSilenceableError() 1251 << "must provide consumed/readonly status for arguments of " 1252 "external or called ops"; 1253 } 1254 if (op.isExternal()) 1255 continue; 1256 1257 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { 1258 return transformOp.emitSilenceableError() 1259 << "argument #" << i 1260 << " is consumed in the body but is not marked as such"; 1261 } 1262 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) { 1263 // Cannot use op.emitWarning() here as it would attempt to verify the op 1264 // before printing, resulting in infinite recursion. 
1265 emitWarning(op->getLoc()) 1266 << "op argument #" << i 1267 << " is not consumed in the body but is marked as consumed"; 1268 } 1269 } 1270 return DiagnosedSilenceableFailure::success(); 1271 } 1272 1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses( 1274 SymbolTableCollection &symbolTable) { 1275 assert(getMatchers().size() == getActions().size()); 1276 auto consumedAttr = 1277 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); 1278 for (auto &&[matcher, action] : 1279 llvm::zip_equal(getMatchers(), getActions())) { 1280 // Presence and typing. 1281 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 1282 symbolTable.lookupNearestSymbolFrom(getOperation(), 1283 cast<SymbolRefAttr>(matcher))); 1284 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>( 1285 symbolTable.lookupNearestSymbolFrom(getOperation(), 1286 cast<SymbolRefAttr>(action))); 1287 if (!matcherSymbol || 1288 !isa<TransformOpInterface>(matcherSymbol.getOperation())) 1289 return emitError() << "unresolved matcher symbol " << matcher; 1290 if (!actionSymbol || 1291 !isa<TransformOpInterface>(actionSymbol.getOperation())) 1292 return emitError() << "unresolved action symbol " << action; 1293 1294 if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, 1295 /*emitWarnings=*/false, 1296 /*alsoVerifyInternal=*/true) 1297 .checkAndReport())) { 1298 return failure(); 1299 } 1300 if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, 1301 /*emitWarnings=*/false, 1302 /*alsoVerifyInternal=*/true) 1303 .checkAndReport())) { 1304 return failure(); 1305 } 1306 1307 // Input -> matcher forwarding. 
1308 TypeRange operandTypes = getOperandTypes(); 1309 TypeRange matcherArguments = matcherSymbol.getArgumentTypes(); 1310 if (operandTypes.size() != matcherArguments.size()) { 1311 InFlightDiagnostic diag = 1312 emitError() << "the number of operands (" << operandTypes.size() 1313 << ") doesn't match the number of matcher arguments (" 1314 << matcherArguments.size() << ") for " << matcher; 1315 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1316 return diag; 1317 } 1318 for (auto &&[i, operand, argument] : 1319 llvm::enumerate(operandTypes, matcherArguments)) { 1320 if (matcherSymbol.getArgAttr(i, consumedAttr)) { 1321 InFlightDiagnostic diag = 1322 emitOpError() 1323 << "does not expect matcher symbol to consume its operand #" << i; 1324 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1325 return diag; 1326 } 1327 1328 if (implementSameTransformInterface(operand, argument)) 1329 continue; 1330 1331 InFlightDiagnostic diag = 1332 emitError() 1333 << "mismatching type interfaces for operand and matcher argument #" 1334 << i << " of matcher " << matcher; 1335 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1336 return diag; 1337 } 1338 1339 // Matcher -> action forwarding. 
1340 TypeRange matcherResults = matcherSymbol.getResultTypes(); 1341 TypeRange actionArguments = actionSymbol.getArgumentTypes(); 1342 if (matcherResults.size() != actionArguments.size()) { 1343 return emitError() << "mismatching number of matcher results and " 1344 "action arguments between " 1345 << matcher << " (" << matcherResults.size() << ") and " 1346 << action << " (" << actionArguments.size() << ")"; 1347 } 1348 for (auto &&[i, matcherType, actionType] : 1349 llvm::enumerate(matcherResults, actionArguments)) { 1350 if (implementSameTransformInterface(matcherType, actionType)) 1351 continue; 1352 1353 return emitError() << "mismatching type interfaces for matcher result " 1354 "and action argument #" 1355 << i << "of matcher " << matcher << " and action " 1356 << action; 1357 } 1358 1359 // Action -> result forwarding. 1360 TypeRange actionResults = actionSymbol.getResultTypes(); 1361 auto resultTypes = TypeRange(getResultTypes()).drop_front(); 1362 if (actionResults.size() != resultTypes.size()) { 1363 InFlightDiagnostic diag = 1364 emitError() << "the number of action results (" 1365 << actionResults.size() << ") for " << action 1366 << " doesn't match the number of extra op results (" 1367 << resultTypes.size() << ")"; 1368 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; 1369 return diag; 1370 } 1371 for (auto &&[i, resultType, actionType] : 1372 llvm::enumerate(resultTypes, actionResults)) { 1373 if (implementSameTransformInterface(resultType, actionType)) 1374 continue; 1375 1376 InFlightDiagnostic diag = 1377 emitError() << "mismatching type interfaces for action result #" << i 1378 << " of action " << action << " and op result"; 1379 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; 1380 return diag; 1381 } 1382 } 1383 return success(); 1384 } 1385 1386 //===----------------------------------------------------------------------===// 1387 // ForeachOp 1388 
//===----------------------------------------------------------------------===// 1389 1390 DiagnosedSilenceableFailure 1391 transform::ForeachOp::apply(transform::TransformRewriter &rewriter, 1392 transform::TransformResults &results, 1393 transform::TransformState &state) { 1394 // We store the payloads before executing the body as ops may be removed from 1395 // the mapping by the TrackingRewriter while iteration is in progress. 1396 SmallVector<SmallVector<MappedValue>> payloads; 1397 detail::prepareValueMappings(payloads, getTargets(), state); 1398 size_t numIterations = payloads.empty() ? 0 : payloads.front().size(); 1399 bool withZipShortest = getWithZipShortest(); 1400 1401 // In case of `zip_shortest`, set the number of iterations to the 1402 // smallest payload in the targets. 1403 if (withZipShortest) { 1404 numIterations = 1405 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A, 1406 const SmallVector<MappedValue> &B) { 1407 return A.size() < B.size(); 1408 })->size(); 1409 1410 for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++) 1411 payloads[argIdx].resize(numIterations); 1412 } 1413 1414 // As we will be "zipping" over them, check all payloads have the same size. 1415 // `zip_shortest` adjusts all payloads to the same size, so skip this check 1416 // when true. 1417 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size(); 1418 argIdx++) { 1419 if (payloads[argIdx].size() != numIterations) { 1420 return emitSilenceableError() 1421 << "prior targets' payload size (" << numIterations 1422 << ") differs from payload size (" << payloads[argIdx].size() 1423 << ") of target " << getTargets()[argIdx]; 1424 } 1425 } 1426 1427 // Start iterating, indexing into payloads to obtain the right arguments to 1428 // call the body with - each slice of payloads at the same argument index 1429 // corresponding to a tuple to use as the body's block arguments. 
1430 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments(); 1431 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {}); 1432 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) { 1433 auto scope = state.make_region_scope(getBody()); 1434 // Set up arguments to the region's block. 1435 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) { 1436 MappedValue argument = payloads[argIdx][iterIdx]; 1437 // Note that each blockArg's handle gets associated with just a single 1438 // element from the corresponding target's payload. 1439 if (failed(state.mapBlockArgument(blockArg, {argument}))) 1440 return DiagnosedSilenceableFailure::definiteFailure(); 1441 } 1442 1443 // Execute loop body. 1444 for (Operation &transform : getBody().front().without_terminator()) { 1445 DiagnosedSilenceableFailure result = state.applyTransform( 1446 llvm::cast<transform::TransformOpInterface>(transform)); 1447 if (!result.succeeded()) 1448 return result; 1449 } 1450 1451 // Append yielded payloads to corresponding results from prior iterations. 1452 OperandRange yieldOperands = getYieldOp().getOperands(); 1453 for (auto &&[result, yieldOperand, resTuple] : 1454 llvm::zip_equal(getResults(), yieldOperands, zippedResults)) 1455 // NB: each iteration we add any number of ops/vals/params to a result. 1456 if (isa<TransformHandleTypeInterface>(result.getType())) 1457 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand)); 1458 else if (isa<TransformValueHandleTypeInterface>(result.getType())) 1459 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand)); 1460 else if (isa<TransformParamTypeInterface>(result.getType())) 1461 llvm::append_range(resTuple, state.getParams(yieldOperand)); 1462 else 1463 assert(false && "unhandled handle type"); 1464 } 1465 1466 // Associate the accumulated result payloads to the op's actual results. 
1467 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults)) 1468 results.setMappedValues(llvm::cast<OpResult>(result), resPayload); 1469 1470 return DiagnosedSilenceableFailure::success(); 1471 } 1472 1473 void transform::ForeachOp::getEffects( 1474 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1475 // NB: this `zip` should be `zip_equal` - while this op's verifier catches 1476 // arity errors, this method might get called before/in absence of `verify()`. 1477 for (auto &&[target, blockArg] : 1478 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) { 1479 BlockArgument blockArgument = blockArg; 1480 if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1481 return isHandleConsumed(blockArgument, 1482 cast<TransformOpInterface>(&op)); 1483 })) { 1484 consumesHandle(target, effects); 1485 } else { 1486 onlyReadsHandle(target, effects); 1487 } 1488 } 1489 1490 if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1491 return doesModifyPayload(cast<TransformOpInterface>(&op)); 1492 })) { 1493 modifiesPayload(effects); 1494 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1495 return doesReadPayload(cast<TransformOpInterface>(&op)); 1496 })) { 1497 onlyReadsPayload(effects); 1498 } 1499 1500 producesHandle(getOperation()->getOpResults(), effects); 1501 } 1502 1503 void transform::ForeachOp::getSuccessorRegions( 1504 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1505 Region *bodyRegion = &getBody(); 1506 if (point.isParent()) { 1507 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1508 return; 1509 } 1510 1511 // Branch back to the region or the parent. 
1512 assert(point == getBody() && "unexpected region index"); 1513 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1514 regions.emplace_back(); 1515 } 1516 1517 OperandRange 1518 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { 1519 // Each block argument handle is mapped to a subset (one op to be precise) 1520 // of the payload of the corresponding `targets` operand of ForeachOp. 1521 assert(point == getBody() && "unexpected region index"); 1522 return getOperation()->getOperands(); 1523 } 1524 1525 transform::YieldOp transform::ForeachOp::getYieldOp() { 1526 return cast<transform::YieldOp>(getBody().front().getTerminator()); 1527 } 1528 1529 LogicalResult transform::ForeachOp::verify() { 1530 for (auto [targetOpt, bodyArgOpt] : 1531 llvm::zip_longest(getTargets(), getBody().front().getArguments())) { 1532 if (!targetOpt || !bodyArgOpt) 1533 return emitOpError() << "expects the same number of targets as the body " 1534 "has block arguments"; 1535 if (targetOpt.value().getType() != bodyArgOpt.value().getType()) 1536 return emitOpError( 1537 "expects co-indexed targets and the body's " 1538 "block arguments to have the same op/value/param type"); 1539 } 1540 1541 for (auto [resultOpt, yieldOperandOpt] : 1542 llvm::zip_longest(getResults(), getYieldOp().getOperands())) { 1543 if (!resultOpt || !yieldOperandOpt) 1544 return emitOpError() << "expects the same number of results as the " 1545 "yield terminator has operands"; 1546 if (resultOpt.value().getType() != yieldOperandOpt.value().getType()) 1547 return emitOpError("expects co-indexed results and yield " 1548 "operands to have the same op/value/param type"); 1549 } 1550 1551 return success(); 1552 } 1553 1554 //===----------------------------------------------------------------------===// 1555 // GetParentOp 1556 //===----------------------------------------------------------------------===// 1557 1558 DiagnosedSilenceableFailure 1559 
transform::GetParentOp::apply(transform::TransformRewriter &rewriter, 1560 transform::TransformResults &results, 1561 transform::TransformState &state) { 1562 SmallVector<Operation *> parents; 1563 DenseSet<Operation *> resultSet; 1564 for (Operation *target : state.getPayloadOps(getTarget())) { 1565 Operation *parent = target; 1566 for (int64_t i = 0, e = getNthParent(); i < e; ++i) { 1567 parent = parent->getParentOp(); 1568 while (parent) { 1569 bool checkIsolatedFromAbove = 1570 !getIsolatedFromAbove() || 1571 parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); 1572 bool checkOpName = !getOpName().has_value() || 1573 parent->getName().getStringRef() == *getOpName(); 1574 if (checkIsolatedFromAbove && checkOpName) 1575 break; 1576 parent = parent->getParentOp(); 1577 } 1578 if (!parent) { 1579 if (getAllowEmptyResults()) { 1580 results.set(llvm::cast<OpResult>(getResult()), parents); 1581 return DiagnosedSilenceableFailure::success(); 1582 } 1583 DiagnosedSilenceableFailure diag = 1584 emitSilenceableError() 1585 << "could not find a parent op that matches all requirements"; 1586 diag.attachNote(target->getLoc()) << "target op"; 1587 return diag; 1588 } 1589 } 1590 if (getDeduplicate()) { 1591 if (resultSet.insert(parent).second) 1592 parents.push_back(parent); 1593 } else { 1594 parents.push_back(parent); 1595 } 1596 } 1597 results.set(llvm::cast<OpResult>(getResult()), parents); 1598 return DiagnosedSilenceableFailure::success(); 1599 } 1600 1601 //===----------------------------------------------------------------------===// 1602 // GetConsumersOfResult 1603 //===----------------------------------------------------------------------===// 1604 1605 DiagnosedSilenceableFailure 1606 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, 1607 transform::TransformResults &results, 1608 transform::TransformState &state) { 1609 int64_t resultNumber = getResultNumber(); 1610 auto payloadOps = state.getPayloadOps(getTarget()); 1611 if 
(std::empty(payloadOps)) { 1612 results.set(cast<OpResult>(getResult()), {}); 1613 return DiagnosedSilenceableFailure::success(); 1614 } 1615 if (!llvm::hasSingleElement(payloadOps)) 1616 return emitDefiniteFailure() 1617 << "handle must be mapped to exactly one payload op"; 1618 1619 Operation *target = *payloadOps.begin(); 1620 if (target->getNumResults() <= resultNumber) 1621 return emitDefiniteFailure() << "result number overflow"; 1622 results.set(llvm::cast<OpResult>(getResult()), 1623 llvm::to_vector(target->getResult(resultNumber).getUsers())); 1624 return DiagnosedSilenceableFailure::success(); 1625 } 1626 1627 //===----------------------------------------------------------------------===// 1628 // GetDefiningOp 1629 //===----------------------------------------------------------------------===// 1630 1631 DiagnosedSilenceableFailure 1632 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, 1633 transform::TransformResults &results, 1634 transform::TransformState &state) { 1635 SmallVector<Operation *> definingOps; 1636 for (Value v : state.getPayloadValues(getTarget())) { 1637 if (llvm::isa<BlockArgument>(v)) { 1638 DiagnosedSilenceableFailure diag = 1639 emitSilenceableError() << "cannot get defining op of block argument"; 1640 diag.attachNote(v.getLoc()) << "target value"; 1641 return diag; 1642 } 1643 definingOps.push_back(v.getDefiningOp()); 1644 } 1645 results.set(llvm::cast<OpResult>(getResult()), definingOps); 1646 return DiagnosedSilenceableFailure::success(); 1647 } 1648 1649 //===----------------------------------------------------------------------===// 1650 // GetProducerOfOperand 1651 //===----------------------------------------------------------------------===// 1652 1653 DiagnosedSilenceableFailure 1654 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, 1655 transform::TransformResults &results, 1656 transform::TransformState &state) { 1657 int64_t operandNumber = getOperandNumber(); 1658 
SmallVector<Operation *> producers; 1659 for (Operation *target : state.getPayloadOps(getTarget())) { 1660 Operation *producer = 1661 target->getNumOperands() <= operandNumber 1662 ? nullptr 1663 : target->getOperand(operandNumber).getDefiningOp(); 1664 if (!producer) { 1665 DiagnosedSilenceableFailure diag = 1666 emitSilenceableError() 1667 << "could not find a producer for operand number: " << operandNumber 1668 << " of " << *target; 1669 diag.attachNote(target->getLoc()) << "target op"; 1670 return diag; 1671 } 1672 producers.push_back(producer); 1673 } 1674 results.set(llvm::cast<OpResult>(getResult()), producers); 1675 return DiagnosedSilenceableFailure::success(); 1676 } 1677 1678 //===----------------------------------------------------------------------===// 1679 // GetOperandOp 1680 //===----------------------------------------------------------------------===// 1681 1682 DiagnosedSilenceableFailure 1683 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, 1684 transform::TransformResults &results, 1685 transform::TransformState &state) { 1686 SmallVector<Value> operands; 1687 for (Operation *target : state.getPayloadOps(getTarget())) { 1688 SmallVector<int64_t> operandPositions; 1689 DiagnosedSilenceableFailure diag = expandTargetSpecification( 1690 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 1691 target->getNumOperands(), operandPositions); 1692 if (diag.isSilenceableFailure()) { 1693 diag.attachNote(target->getLoc()) 1694 << "while considering positions of this payload operation"; 1695 return diag; 1696 } 1697 llvm::append_range(operands, 1698 llvm::map_range(operandPositions, [&](int64_t pos) { 1699 return target->getOperand(pos); 1700 })); 1701 } 1702 results.setValues(cast<OpResult>(getResult()), operands); 1703 return DiagnosedSilenceableFailure::success(); 1704 } 1705 1706 LogicalResult transform::GetOperandOp::verify() { 1707 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 1708 
getIsInverted(), getIsAll()); 1709 } 1710 1711 //===----------------------------------------------------------------------===// 1712 // GetResultOp 1713 //===----------------------------------------------------------------------===// 1714 1715 DiagnosedSilenceableFailure 1716 transform::GetResultOp::apply(transform::TransformRewriter &rewriter, 1717 transform::TransformResults &results, 1718 transform::TransformState &state) { 1719 SmallVector<Value> opResults; 1720 for (Operation *target : state.getPayloadOps(getTarget())) { 1721 SmallVector<int64_t> resultPositions; 1722 DiagnosedSilenceableFailure diag = expandTargetSpecification( 1723 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 1724 target->getNumResults(), resultPositions); 1725 if (diag.isSilenceableFailure()) { 1726 diag.attachNote(target->getLoc()) 1727 << "while considering positions of this payload operation"; 1728 return diag; 1729 } 1730 llvm::append_range(opResults, 1731 llvm::map_range(resultPositions, [&](int64_t pos) { 1732 return target->getResult(pos); 1733 })); 1734 } 1735 results.setValues(cast<OpResult>(getResult()), opResults); 1736 return DiagnosedSilenceableFailure::success(); 1737 } 1738 1739 LogicalResult transform::GetResultOp::verify() { 1740 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 1741 getIsInverted(), getIsAll()); 1742 } 1743 1744 //===----------------------------------------------------------------------===// 1745 // GetTypeOp 1746 //===----------------------------------------------------------------------===// 1747 1748 void transform::GetTypeOp::getEffects( 1749 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1750 onlyReadsHandle(getValueMutable(), effects); 1751 producesHandle(getOperation()->getOpResults(), effects); 1752 onlyReadsPayload(effects); 1753 } 1754 1755 DiagnosedSilenceableFailure 1756 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, 1757 transform::TransformResults &results, 1758 
transform::TransformState &state) { 1759 SmallVector<Attribute> params; 1760 for (Value value : state.getPayloadValues(getValue())) { 1761 Type type = value.getType(); 1762 if (getElemental()) { 1763 if (auto shaped = dyn_cast<ShapedType>(type)) { 1764 type = shaped.getElementType(); 1765 } 1766 } 1767 params.push_back(TypeAttr::get(type)); 1768 } 1769 results.setParams(cast<OpResult>(getResult()), params); 1770 return DiagnosedSilenceableFailure::success(); 1771 } 1772 1773 //===----------------------------------------------------------------------===// 1774 // IncludeOp 1775 //===----------------------------------------------------------------------===// 1776 1777 /// Applies the transform ops contained in `block`. Maps `results` to the same 1778 /// values as the operands of the block terminator. 1779 static DiagnosedSilenceableFailure 1780 applySequenceBlock(Block &block, transform::FailurePropagationMode mode, 1781 transform::TransformState &state, 1782 transform::TransformResults &results) { 1783 // Apply the sequenced ops one by one. 1784 for (Operation &transform : block.without_terminator()) { 1785 DiagnosedSilenceableFailure result = 1786 state.applyTransform(cast<transform::TransformOpInterface>(transform)); 1787 if (result.isDefiniteFailure()) 1788 return result; 1789 1790 if (result.isSilenceableFailure()) { 1791 if (mode == transform::FailurePropagationMode::Propagate) { 1792 // Propagate empty results in case of early exit. 1793 forwardEmptyOperands(&block, state, results); 1794 return result; 1795 } 1796 (void)result.silence(); 1797 } 1798 } 1799 1800 // Forward the operation mapping for values yielded from the sequence to the 1801 // values produced by the sequence op. 
1802 transform::detail::forwardTerminatorOperands(&block, state, results); 1803 return DiagnosedSilenceableFailure::success(); 1804 } 1805 1806 DiagnosedSilenceableFailure 1807 transform::IncludeOp::apply(transform::TransformRewriter &rewriter, 1808 transform::TransformResults &results, 1809 transform::TransformState &state) { 1810 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1811 getOperation(), getTarget()); 1812 assert(callee && "unverified reference to unknown symbol"); 1813 1814 if (callee.isExternal()) 1815 return emitDefiniteFailure() << "unresolved external named sequence"; 1816 1817 // Map operands to block arguments. 1818 SmallVector<SmallVector<MappedValue>> mappings; 1819 detail::prepareValueMappings(mappings, getOperands(), state); 1820 auto scope = state.make_region_scope(callee.getBody()); 1821 for (auto &&[arg, map] : 1822 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { 1823 if (failed(state.mapBlockArgument(arg, map))) 1824 return DiagnosedSilenceableFailure::definiteFailure(); 1825 } 1826 1827 DiagnosedSilenceableFailure result = applySequenceBlock( 1828 callee.getBody().front(), getFailurePropagationMode(), state, results); 1829 mappings.clear(); 1830 detail::prepareValueMappings( 1831 mappings, callee.getBody().front().getTerminator()->getOperands(), state); 1832 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) 1833 results.setMappedValues(result, mapping); 1834 return result; 1835 } 1836 1837 static DiagnosedSilenceableFailure 1838 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings); 1839 1840 void transform::IncludeOp::getEffects( 1841 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1842 // Always mark as modifying the payload. 1843 // TODO: a mechanism to annotate effects on payload. 
Even when all handles are 1844 // only read, the payload may still be modified, so we currently stay on the 1845 // conservative side and always indicate modification. This may prevent some 1846 // code reordering. 1847 modifiesPayload(effects); 1848 1849 // Results are always produced. 1850 producesHandle(getOperation()->getOpResults(), effects); 1851 1852 // Adds default effects to operands and results. This will be added if 1853 // preconditions fail so the trait verifier doesn't complain about missing 1854 // effects and the real precondition failure is reported later on. 1855 auto defaultEffects = [&] { 1856 onlyReadsHandle(getOperation()->getOpOperands(), effects); 1857 }; 1858 1859 // Bail if the callee is unknown. This may run as part of the verification 1860 // process before we verified the validity of the callee or of this op. 1861 auto target = 1862 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName()); 1863 if (!target) 1864 return defaultEffects(); 1865 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1866 getOperation(), getTarget()); 1867 if (!callee) 1868 return defaultEffects(); 1869 DiagnosedSilenceableFailure earlyVerifierResult = 1870 verifyNamedSequenceOp(callee, /*emitWarnings=*/false); 1871 if (!earlyVerifierResult.succeeded()) { 1872 (void)earlyVerifierResult.silence(); 1873 return defaultEffects(); 1874 } 1875 1876 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { 1877 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) 1878 consumesHandle(getOperation()->getOpOperand(i), effects); 1879 else 1880 onlyReadsHandle(getOperation()->getOpOperand(i), effects); 1881 } 1882 } 1883 1884 LogicalResult 1885 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1886 // Access through indirection and do additional checking because this may be 1887 // running before the main op verifier. 
1888 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target"); 1889 if (!targetAttr) 1890 return emitOpError() << "expects a 'target' symbol reference attribute"; 1891 1892 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>( 1893 *this, targetAttr); 1894 if (!target) 1895 return emitOpError() << "does not reference a named transform sequence"; 1896 1897 FunctionType fnType = target.getFunctionType(); 1898 if (fnType.getNumInputs() != getNumOperands()) 1899 return emitError("incorrect number of operands for callee"); 1900 1901 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { 1902 if (getOperand(i).getType() != fnType.getInput(i)) { 1903 return emitOpError("operand type mismatch: expected operand type ") 1904 << fnType.getInput(i) << ", but provided " 1905 << getOperand(i).getType() << " for operand number " << i; 1906 } 1907 } 1908 1909 if (fnType.getNumResults() != getNumResults()) 1910 return emitError("incorrect number of results for callee"); 1911 1912 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { 1913 Type resultType = getResult(i).getType(); 1914 Type funcType = fnType.getResult(i); 1915 if (!implementSameTransformInterface(resultType, funcType)) { 1916 return emitOpError() << "type of result #" << i 1917 << " must implement the same transform dialect " 1918 "interface as the corresponding callee result"; 1919 } 1920 } 1921 1922 return verifyFunctionLikeConsumeAnnotations( 1923 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false, 1924 /*alsoVerifyInternal=*/true) 1925 .checkAndReport(); 1926 } 1927 1928 //===----------------------------------------------------------------------===// 1929 // MatchOperationEmptyOp 1930 //===----------------------------------------------------------------------===// 1931 1932 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( 1933 ::std::optional<::mlir::Operation *> maybeCurrent, 1934 transform::TransformResults &results, 
transform::TransformState &state) { 1935 if (!maybeCurrent.has_value()) { 1936 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); 1937 return DiagnosedSilenceableFailure::success(); 1938 } 1939 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); 1940 return emitSilenceableError() << "operation is not empty"; 1941 } 1942 1943 //===----------------------------------------------------------------------===// 1944 // MatchOperationNameOp 1945 //===----------------------------------------------------------------------===// 1946 1947 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( 1948 Operation *current, transform::TransformResults &results, 1949 transform::TransformState &state) { 1950 StringRef currentOpName = current->getName().getStringRef(); 1951 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) { 1952 if (acceptedAttr.getValue() == currentOpName) 1953 return DiagnosedSilenceableFailure::success(); 1954 } 1955 return emitSilenceableError() << "wrong operation name"; 1956 } 1957 1958 //===----------------------------------------------------------------------===// 1959 // MatchParamCmpIOp 1960 //===----------------------------------------------------------------------===// 1961 1962 DiagnosedSilenceableFailure 1963 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, 1964 transform::TransformResults &results, 1965 transform::TransformState &state) { 1966 auto signedAPIntAsString = [&](const APInt &value) { 1967 std::string str; 1968 llvm::raw_string_ostream os(str); 1969 value.print(os, /*isSigned=*/true); 1970 return str; 1971 }; 1972 1973 ArrayRef<Attribute> params = state.getParams(getParam()); 1974 ArrayRef<Attribute> references = state.getParams(getReference()); 1975 1976 if (params.size() != references.size()) { 1977 return emitSilenceableError() 1978 << "parameters have different payload lengths (" << params.size() 1979 << " vs " << references.size() << 
")"; 1980 } 1981 1982 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { 1983 auto intAttr = llvm::dyn_cast<IntegerAttr>(param); 1984 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference); 1985 if (!intAttr || !refAttr) { 1986 return emitDefiniteFailure() 1987 << "non-integer parameter value not expected"; 1988 } 1989 if (intAttr.getType() != refAttr.getType()) { 1990 return emitDefiniteFailure() 1991 << "mismatching integer attribute types in parameter #" << i; 1992 } 1993 APInt value = intAttr.getValue(); 1994 APInt refValue = refAttr.getValue(); 1995 1996 // TODO: this copy will not be necessary in C++20. 1997 int64_t position = i; 1998 auto reportError = [&](StringRef direction) { 1999 DiagnosedSilenceableFailure diag = 2000 emitSilenceableError() << "expected parameter to be " << direction 2001 << " " << signedAPIntAsString(refValue) 2002 << ", got " << signedAPIntAsString(value); 2003 diag.attachNote(getParam().getLoc()) 2004 << "value # " << position 2005 << " associated with the parameter defined here"; 2006 return diag; 2007 }; 2008 2009 switch (getPredicate()) { 2010 case MatchCmpIPredicate::eq: 2011 if (value.eq(refValue)) 2012 break; 2013 return reportError("equal to"); 2014 case MatchCmpIPredicate::ne: 2015 if (value.ne(refValue)) 2016 break; 2017 return reportError("not equal to"); 2018 case MatchCmpIPredicate::lt: 2019 if (value.slt(refValue)) 2020 break; 2021 return reportError("less than"); 2022 case MatchCmpIPredicate::le: 2023 if (value.sle(refValue)) 2024 break; 2025 return reportError("less than or equal to"); 2026 case MatchCmpIPredicate::gt: 2027 if (value.sgt(refValue)) 2028 break; 2029 return reportError("greater than"); 2030 case MatchCmpIPredicate::ge: 2031 if (value.sge(refValue)) 2032 break; 2033 return reportError("greater than or equal to"); 2034 } 2035 } 2036 return DiagnosedSilenceableFailure::success(); 2037 } 2038 2039 void transform::MatchParamCmpIOp::getEffects( 2040 
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2041 onlyReadsHandle(getParamMutable(), effects); 2042 onlyReadsHandle(getReferenceMutable(), effects); 2043 } 2044 2045 //===----------------------------------------------------------------------===// 2046 // ParamConstantOp 2047 //===----------------------------------------------------------------------===// 2048 2049 DiagnosedSilenceableFailure 2050 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, 2051 transform::TransformResults &results, 2052 transform::TransformState &state) { 2053 results.setParams(cast<OpResult>(getParam()), {getValue()}); 2054 return DiagnosedSilenceableFailure::success(); 2055 } 2056 2057 //===----------------------------------------------------------------------===// 2058 // MergeHandlesOp 2059 //===----------------------------------------------------------------------===// 2060 2061 DiagnosedSilenceableFailure 2062 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, 2063 transform::TransformResults &results, 2064 transform::TransformState &state) { 2065 ValueRange handles = getHandles(); 2066 if (isa<TransformHandleTypeInterface>(handles.front().getType())) { 2067 SmallVector<Operation *> operations; 2068 for (Value operand : handles) 2069 llvm::append_range(operations, state.getPayloadOps(operand)); 2070 if (!getDeduplicate()) { 2071 results.set(llvm::cast<OpResult>(getResult()), operations); 2072 return DiagnosedSilenceableFailure::success(); 2073 } 2074 2075 SetVector<Operation *> uniqued(operations.begin(), operations.end()); 2076 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); 2077 return DiagnosedSilenceableFailure::success(); 2078 } 2079 2080 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) { 2081 SmallVector<Attribute> attrs; 2082 for (Value attribute : handles) 2083 llvm::append_range(attrs, state.getParams(attribute)); 2084 if (!getDeduplicate()) { 2085 
results.setParams(cast<OpResult>(getResult()), attrs); 2086 return DiagnosedSilenceableFailure::success(); 2087 } 2088 2089 SetVector<Attribute> uniqued(attrs.begin(), attrs.end()); 2090 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); 2091 return DiagnosedSilenceableFailure::success(); 2092 } 2093 2094 assert( 2095 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) && 2096 "expected value handle type"); 2097 SmallVector<Value> payloadValues; 2098 for (Value value : handles) 2099 llvm::append_range(payloadValues, state.getPayloadValues(value)); 2100 if (!getDeduplicate()) { 2101 results.setValues(cast<OpResult>(getResult()), payloadValues); 2102 return DiagnosedSilenceableFailure::success(); 2103 } 2104 2105 SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end()); 2106 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); 2107 return DiagnosedSilenceableFailure::success(); 2108 } 2109 2110 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { 2111 // Handles may be the same if deduplicating is enabled. 2112 return getDeduplicate(); 2113 } 2114 2115 void transform::MergeHandlesOp::getEffects( 2116 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2117 onlyReadsHandle(getHandlesMutable(), effects); 2118 producesHandle(getOperation()->getOpResults(), effects); 2119 2120 // There are no effects on the Payload IR as this is only a handle 2121 // manipulation. 2122 } 2123 2124 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { 2125 if (getDeduplicate() || getHandles().size() != 1) 2126 return {}; 2127 2128 // If deduplication is not required and there is only one operand, it can be 2129 // used directly instead of merging. 
2130 return getHandles().front(); 2131 } 2132 2133 //===----------------------------------------------------------------------===// 2134 // NamedSequenceOp 2135 //===----------------------------------------------------------------------===// 2136 2137 DiagnosedSilenceableFailure 2138 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, 2139 transform::TransformResults &results, 2140 transform::TransformState &state) { 2141 if (isExternal()) 2142 return emitDefiniteFailure() << "unresolved external named sequence"; 2143 2144 // Map the entry block argument to the list of operations. 2145 // Note: this is the same implementation as PossibleTopLevelTransformOp but 2146 // without attaching the interface / trait since that is tailored to a 2147 // dangling top-level op that does not get "called". 2148 auto scope = state.make_region_scope(getBody()); 2149 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( 2150 state, this->getOperation(), getBody()))) 2151 return DiagnosedSilenceableFailure::definiteFailure(); 2152 2153 return applySequenceBlock(getBody().front(), 2154 FailurePropagationMode::Propagate, state, results); 2155 } 2156 2157 void transform::NamedSequenceOp::getEffects( 2158 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 2159 2160 ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, 2161 OperationState &result) { 2162 return function_interface_impl::parseFunctionOp( 2163 parser, result, /*allowVariadic=*/false, 2164 getFunctionTypeAttrName(result.name), 2165 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results, 2166 function_interface_impl::VariadicFlag, 2167 std::string &) { return builder.getFunctionType(inputs, results); }, 2168 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 2169 } 2170 2171 void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { 2172 function_interface_impl::printFunctionOp( 2173 printer, cast<FunctionOpInterface>(getOperation()), 
/*isVariadic=*/false, 2174 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), 2175 getResAttrsAttrName()); 2176 } 2177 2178 /// Verifies that a symbol function-like transform dialect operation has the 2179 /// signature and the terminator that have conforming types, i.e., types 2180 /// implementing the same transform dialect type interface. If `allowExternal` 2181 /// is set, allow external symbols (declarations) and don't check the terminator 2182 /// as it may not exist. 2183 static DiagnosedSilenceableFailure 2184 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { 2185 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 2186 DiagnosedSilenceableFailure diag = 2187 emitSilenceableFailure(op) 2188 << "cannot be defined inside another transform op"; 2189 diag.attachNote(parent.getLoc()) << "ancestor transform op"; 2190 return diag; 2191 } 2192 2193 if (op.isExternal() || op.getFunctionBody().empty()) { 2194 if (allowExternal) 2195 return DiagnosedSilenceableFailure::success(); 2196 2197 return emitSilenceableFailure(op) << "cannot be external"; 2198 } 2199 2200 if (op.getFunctionBody().front().empty()) 2201 return emitSilenceableFailure(op) << "expected a non-empty body block"; 2202 2203 Operation *terminator = &op.getFunctionBody().front().back(); 2204 if (!isa<transform::YieldOp>(terminator)) { 2205 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 2206 << "expected '" 2207 << transform::YieldOp::getOperationName() 2208 << "' as terminator"; 2209 diag.attachNote(terminator->getLoc()) << "terminator"; 2210 return diag; 2211 } 2212 2213 if (terminator->getNumOperands() != op.getResultTypes().size()) { 2214 return emitSilenceableFailure(terminator) 2215 << "expected terminator to have as many operands as the parent op " 2216 "has results"; 2217 } 2218 for (auto [i, operandType, resultType] : llvm::zip_equal( 2219 llvm::seq<unsigned>(0, terminator->getNumOperands()), 2220 
terminator->getOperands().getType(), op.getResultTypes())) { 2221 if (operandType == resultType) 2222 continue; 2223 return emitSilenceableFailure(terminator) 2224 << "the type of the terminator operand #" << i 2225 << " must match the type of the corresponding parent op result (" 2226 << operandType << " vs " << resultType << ")"; 2227 } 2228 2229 return DiagnosedSilenceableFailure::success(); 2230 } 2231 2232 /// Verification of a NamedSequenceOp. This does not report the error 2233 /// immediately, so it can be used to check for op's well-formedness before the 2234 /// verifier runs, e.g., during trait verification. 2235 static DiagnosedSilenceableFailure 2236 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { 2237 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) { 2238 if (!parent->getAttr( 2239 transform::TransformDialect::kWithNamedSequenceAttrName)) { 2240 DiagnosedSilenceableFailure diag = 2241 emitSilenceableFailure(op) 2242 << "expects the parent symbol table to have the '" 2243 << transform::TransformDialect::kWithNamedSequenceAttrName 2244 << "' attribute"; 2245 diag.attachNote(parent->getLoc()) << "symbol table operation"; 2246 return diag; 2247 } 2248 } 2249 2250 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 2251 DiagnosedSilenceableFailure diag = 2252 emitSilenceableFailure(op) 2253 << "cannot be defined inside another transform op"; 2254 diag.attachNote(parent.getLoc()) << "ancestor transform op"; 2255 return diag; 2256 } 2257 2258 if (op.isExternal() || op.getBody().empty()) 2259 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op), 2260 emitWarnings); 2261 2262 if (op.getBody().front().empty()) 2263 return emitSilenceableFailure(op) << "expected a non-empty body block"; 2264 2265 Operation *terminator = &op.getBody().front().back(); 2266 if (!isa<transform::YieldOp>(terminator)) { 2267 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 2268 
<< "expected '" 2269 << transform::YieldOp::getOperationName() 2270 << "' as terminator"; 2271 diag.attachNote(terminator->getLoc()) << "terminator"; 2272 return diag; 2273 } 2274 2275 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { 2276 return emitSilenceableFailure(terminator) 2277 << "expected terminator to have as many operands as the parent op " 2278 "has results"; 2279 } 2280 for (auto [i, operandType, resultType] : 2281 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()), 2282 terminator->getOperands().getType(), 2283 op.getFunctionType().getResults())) { 2284 if (operandType == resultType) 2285 continue; 2286 return emitSilenceableFailure(terminator) 2287 << "the type of the terminator operand #" << i 2288 << " must match the type of the corresponding parent op result (" 2289 << operandType << " vs " << resultType << ")"; 2290 } 2291 2292 auto funcOp = cast<FunctionOpInterface>(*op); 2293 DiagnosedSilenceableFailure diag = 2294 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); 2295 if (!diag.succeeded()) 2296 return diag; 2297 2298 return verifyYieldingSingleBlockOp(funcOp, 2299 /*allowExternal=*/true); 2300 } 2301 2302 LogicalResult transform::NamedSequenceOp::verify() { 2303 // Actual verification happens in a separate function for reusability. 
2304 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); 2305 } 2306 2307 template <typename FnTy> 2308 static void buildSequenceBody(OpBuilder &builder, OperationState &state, 2309 Type bbArgType, TypeRange extraBindingTypes, 2310 FnTy bodyBuilder) { 2311 SmallVector<Type> types; 2312 types.reserve(1 + extraBindingTypes.size()); 2313 types.push_back(bbArgType); 2314 llvm::append_range(types, extraBindingTypes); 2315 2316 OpBuilder::InsertionGuard guard(builder); 2317 Region *region = state.regions.back().get(); 2318 Block *bodyBlock = 2319 builder.createBlock(region, region->begin(), types, 2320 SmallVector<Location>(types.size(), state.location)); 2321 2322 // Populate body. 2323 builder.setInsertionPointToStart(bodyBlock); 2324 if constexpr (llvm::function_traits<FnTy>::num_args == 3) { 2325 bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); 2326 } else { 2327 bodyBuilder(builder, state.location, bodyBlock->getArgument(0), 2328 bodyBlock->getArguments().drop_front()); 2329 } 2330 } 2331 2332 void transform::NamedSequenceOp::build(OpBuilder &builder, 2333 OperationState &state, StringRef symName, 2334 Type rootType, TypeRange resultTypes, 2335 SequenceBodyBuilderFn bodyBuilder, 2336 ArrayRef<NamedAttribute> attrs, 2337 ArrayRef<DictionaryAttr> argAttrs) { 2338 state.addAttribute(SymbolTable::getSymbolAttrName(), 2339 builder.getStringAttr(symName)); 2340 state.addAttribute(getFunctionTypeAttrName(state.name), 2341 TypeAttr::get(FunctionType::get(builder.getContext(), 2342 rootType, resultTypes))); 2343 state.attributes.append(attrs.begin(), attrs.end()); 2344 state.addRegion(); 2345 2346 buildSequenceBody(builder, state, rootType, 2347 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2348 } 2349 2350 //===----------------------------------------------------------------------===// 2351 // NumAssociationsOp 2352 //===----------------------------------------------------------------------===// 2353 2354 
DiagnosedSilenceableFailure 2355 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, 2356 transform::TransformResults &results, 2357 transform::TransformState &state) { 2358 size_t numAssociations = 2359 llvm::TypeSwitch<Type, size_t>(getHandle().getType()) 2360 .Case([&](TransformHandleTypeInterface opHandle) { 2361 return llvm::range_size(state.getPayloadOps(getHandle())); 2362 }) 2363 .Case([&](TransformValueHandleTypeInterface valueHandle) { 2364 return llvm::range_size(state.getPayloadValues(getHandle())); 2365 }) 2366 .Case([&](TransformParamTypeInterface param) { 2367 return llvm::range_size(state.getParams(getHandle())); 2368 }) 2369 .Default([](Type) { 2370 llvm_unreachable("unknown kind of transform dialect type"); 2371 return 0; 2372 }); 2373 results.setParams(cast<OpResult>(getNum()), 2374 rewriter.getI64IntegerAttr(numAssociations)); 2375 return DiagnosedSilenceableFailure::success(); 2376 } 2377 2378 LogicalResult transform::NumAssociationsOp::verify() { 2379 // Verify that the result type accepts an i64 attribute as payload. 
2380 auto resultType = cast<TransformParamTypeInterface>(getNum().getType()); 2381 return resultType 2382 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) 2383 .checkAndReport(); 2384 } 2385 2386 //===----------------------------------------------------------------------===// 2387 // SelectOp 2388 //===----------------------------------------------------------------------===// 2389 2390 DiagnosedSilenceableFailure 2391 transform::SelectOp::apply(transform::TransformRewriter &rewriter, 2392 transform::TransformResults &results, 2393 transform::TransformState &state) { 2394 SmallVector<Operation *> result; 2395 auto payloadOps = state.getPayloadOps(getTarget()); 2396 for (Operation *op : payloadOps) { 2397 if (op->getName().getStringRef() == getOpName()) 2398 result.push_back(op); 2399 } 2400 results.set(cast<OpResult>(getResult()), result); 2401 return DiagnosedSilenceableFailure::success(); 2402 } 2403 2404 //===----------------------------------------------------------------------===// 2405 // SplitHandleOp 2406 //===----------------------------------------------------------------------===// 2407 2408 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, 2409 Value target, int64_t numResultHandles) { 2410 result.addOperands(target); 2411 result.addTypes(SmallVector<Type>(numResultHandles, target.getType())); 2412 } 2413 2414 DiagnosedSilenceableFailure 2415 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, 2416 transform::TransformResults &results, 2417 transform::TransformState &state) { 2418 int64_t numPayloads = 2419 llvm::TypeSwitch<Type, int64_t>(getHandle().getType()) 2420 .Case<TransformHandleTypeInterface>([&](auto x) { 2421 return llvm::range_size(state.getPayloadOps(getHandle())); 2422 }) 2423 .Case<TransformValueHandleTypeInterface>([&](auto x) { 2424 return llvm::range_size(state.getPayloadValues(getHandle())); 2425 }) 2426 .Case<TransformParamTypeInterface>([&](auto x) { 2427 return 
llvm::range_size(state.getParams(getHandle())); 2428 }) 2429 .Default([](auto x) { 2430 llvm_unreachable("unknown transform dialect type interface"); 2431 return -1; 2432 }); 2433 2434 auto produceNumOpsError = [&]() { 2435 return emitSilenceableError() 2436 << getHandle() << " expected to contain " << this->getNumResults() 2437 << " payloads but it contains " << numPayloads << " payloads"; 2438 }; 2439 2440 // Fail if there are more payload ops than results and no overflow result was 2441 // specified. 2442 if (numPayloads > getNumResults() && !getOverflowResult().has_value()) 2443 return produceNumOpsError(); 2444 2445 // Fail if there are more results than payload ops. Unless: 2446 // - "fail_on_payload_too_small" is set to "false", or 2447 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. 2448 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() && 2449 (numPayloads != 0 || !getPassThroughEmptyHandle())) 2450 return produceNumOpsError(); 2451 2452 // Distribute payloads. 
2453 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {}); 2454 if (getOverflowResult()) 2455 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults()); 2456 2457 auto container = [&]() { 2458 if (isa<TransformHandleTypeInterface>(getHandle().getType())) { 2459 return llvm::map_to_vector( 2460 state.getPayloadOps(getHandle()), 2461 [](Operation *op) -> MappedValue { return op; }); 2462 } 2463 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) { 2464 return llvm::map_to_vector(state.getPayloadValues(getHandle()), 2465 [](Value v) -> MappedValue { return v; }); 2466 } 2467 assert(isa<TransformParamTypeInterface>(getHandle().getType()) && 2468 "unsupported kind of transform dialect type"); 2469 return llvm::map_to_vector(state.getParams(getHandle()), 2470 [](Attribute a) -> MappedValue { return a; }); 2471 }(); 2472 2473 for (auto &&en : llvm::enumerate(container)) { 2474 int64_t resultNum = en.index(); 2475 if (resultNum >= getNumResults()) 2476 resultNum = *getOverflowResult(); 2477 resultHandles[resultNum].push_back(en.value()); 2478 } 2479 2480 // Set transform op results. 2481 for (auto &&it : llvm::enumerate(resultHandles)) 2482 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())), 2483 it.value()); 2484 2485 return DiagnosedSilenceableFailure::success(); 2486 } 2487 2488 void transform::SplitHandleOp::getEffects( 2489 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2490 onlyReadsHandle(getHandleMutable(), effects); 2491 producesHandle(getOperation()->getOpResults(), effects); 2492 // There are no effects on the Payload IR as this is only a handle 2493 // manipulation. 
2494 } 2495 2496 LogicalResult transform::SplitHandleOp::verify() { 2497 if (getOverflowResult().has_value() && 2498 !(*getOverflowResult() < getNumResults())) 2499 return emitOpError("overflow_result is not a valid result index"); 2500 2501 for (Type resultType : getResultTypes()) { 2502 if (implementSameTransformInterface(getHandle().getType(), resultType)) 2503 continue; 2504 2505 return emitOpError("expects result types to implement the same transform " 2506 "interface as the operand type"); 2507 } 2508 2509 return success(); 2510 } 2511 2512 //===----------------------------------------------------------------------===// 2513 // ReplicateOp 2514 //===----------------------------------------------------------------------===// 2515 2516 DiagnosedSilenceableFailure 2517 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, 2518 transform::TransformResults &results, 2519 transform::TransformState &state) { 2520 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); 2521 for (const auto &en : llvm::enumerate(getHandles())) { 2522 Value handle = en.value(); 2523 if (isa<TransformHandleTypeInterface>(handle.getType())) { 2524 SmallVector<Operation *> current = 2525 llvm::to_vector(state.getPayloadOps(handle)); 2526 SmallVector<Operation *> payload; 2527 payload.reserve(numRepetitions * current.size()); 2528 for (unsigned i = 0; i < numRepetitions; ++i) 2529 llvm::append_range(payload, current); 2530 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload); 2531 } else { 2532 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) && 2533 "expected param type"); 2534 ArrayRef<Attribute> current = state.getParams(handle); 2535 SmallVector<Attribute> params; 2536 params.reserve(numRepetitions * current.size()); 2537 for (unsigned i = 0; i < numRepetitions; ++i) 2538 llvm::append_range(params, current); 2539 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]), 2540 params); 2541 } 2542 } 2543 
return DiagnosedSilenceableFailure::success(); 2544 } 2545 2546 void transform::ReplicateOp::getEffects( 2547 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2548 onlyReadsHandle(getPatternMutable(), effects); 2549 onlyReadsHandle(getHandlesMutable(), effects); 2550 producesHandle(getOperation()->getOpResults(), effects); 2551 } 2552 2553 //===----------------------------------------------------------------------===// 2554 // SequenceOp 2555 //===----------------------------------------------------------------------===// 2556 2557 DiagnosedSilenceableFailure 2558 transform::SequenceOp::apply(transform::TransformRewriter &rewriter, 2559 transform::TransformResults &results, 2560 transform::TransformState &state) { 2561 // Map the entry block argument to the list of operations. 2562 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 2563 if (failed(mapBlockArguments(state))) 2564 return DiagnosedSilenceableFailure::definiteFailure(); 2565 2566 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, 2567 results); 2568 } 2569 2570 static ParseResult parseSequenceOpOperands( 2571 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 2572 Type &rootType, 2573 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 2574 SmallVectorImpl<Type> &extraBindingTypes) { 2575 OpAsmParser::UnresolvedOperand rootOperand; 2576 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand); 2577 if (!hasRoot.has_value()) { 2578 root = std::nullopt; 2579 return success(); 2580 } 2581 if (failed(hasRoot.value())) 2582 return failure(); 2583 root = rootOperand; 2584 2585 if (succeeded(parser.parseOptionalComma())) { 2586 if (failed(parser.parseOperandList(extraBindings))) 2587 return failure(); 2588 } 2589 if (failed(parser.parseColon())) 2590 return failure(); 2591 2592 // The paren is truly optional. 
2593 (void)parser.parseOptionalLParen(); 2594 2595 if (failed(parser.parseType(rootType))) { 2596 return failure(); 2597 } 2598 2599 if (!extraBindings.empty()) { 2600 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes)) 2601 return failure(); 2602 } 2603 2604 if (extraBindingTypes.size() != extraBindings.size()) { 2605 return parser.emitError(parser.getNameLoc(), 2606 "expected types to be provided for all operands"); 2607 } 2608 2609 // The paren is truly optional. 2610 (void)parser.parseOptionalRParen(); 2611 return success(); 2612 } 2613 2614 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 2615 Value root, Type rootType, 2616 ValueRange extraBindings, 2617 TypeRange extraBindingTypes) { 2618 if (!root) 2619 return; 2620 2621 printer << root; 2622 bool hasExtras = !extraBindings.empty(); 2623 if (hasExtras) { 2624 printer << ", "; 2625 printer.printOperands(extraBindings); 2626 } 2627 2628 printer << " : "; 2629 if (hasExtras) 2630 printer << "("; 2631 2632 printer << rootType; 2633 if (hasExtras) { 2634 printer << ", "; 2635 llvm::interleaveComma(extraBindingTypes, printer.getStream()); 2636 printer << ")"; 2637 } 2638 } 2639 2640 /// Returns `true` if the given op operand may be consuming the handle value in 2641 /// the Transform IR. That is, if it may have a Free effect on it. 2642 static bool isValueUsePotentialConsumer(OpOperand &use) { 2643 // Conservatively assume the effect being present in absence of the interface. 
2644 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); 2645 if (!iface) 2646 return true; 2647 2648 return isHandleConsumed(use.get(), iface); 2649 } 2650 2651 LogicalResult 2652 checkDoubleConsume(Value value, 2653 function_ref<InFlightDiagnostic()> reportError) { 2654 OpOperand *potentialConsumer = nullptr; 2655 for (OpOperand &use : value.getUses()) { 2656 if (!isValueUsePotentialConsumer(use)) 2657 continue; 2658 2659 if (!potentialConsumer) { 2660 potentialConsumer = &use; 2661 continue; 2662 } 2663 2664 InFlightDiagnostic diag = reportError() 2665 << " has more than one potential consumer"; 2666 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 2667 << "used here as operand #" << potentialConsumer->getOperandNumber(); 2668 diag.attachNote(use.getOwner()->getLoc()) 2669 << "used here as operand #" << use.getOperandNumber(); 2670 return diag; 2671 } 2672 2673 return success(); 2674 } 2675 2676 LogicalResult transform::SequenceOp::verify() { 2677 assert(getBodyBlock()->getNumArguments() >= 1 && 2678 "the number of arguments must have been verified to be more than 1 by " 2679 "PossibleTopLevelTransformOpTrait"); 2680 2681 if (!getRoot() && !getExtraBindings().empty()) { 2682 return emitOpError() 2683 << "does not expect extra operands when used as top-level"; 2684 } 2685 2686 // Check if a block argument has more than one consuming use. 2687 for (BlockArgument arg : getBodyBlock()->getArguments()) { 2688 if (failed(checkDoubleConsume(arg, [this, arg]() { 2689 return (emitOpError() << "block argument #" << arg.getArgNumber()); 2690 }))) { 2691 return failure(); 2692 } 2693 } 2694 2695 // Check properties of the nested operations they cannot check themselves. 
2696 for (Operation &child : *getBodyBlock()) { 2697 if (!isa<TransformOpInterface>(child) && 2698 &child != &getBodyBlock()->back()) { 2699 InFlightDiagnostic diag = 2700 emitOpError() 2701 << "expected children ops to implement TransformOpInterface"; 2702 diag.attachNote(child.getLoc()) << "op without interface"; 2703 return diag; 2704 } 2705 2706 for (OpResult result : child.getResults()) { 2707 auto report = [&]() { 2708 return (child.emitError() << "result #" << result.getResultNumber()); 2709 }; 2710 if (failed(checkDoubleConsume(result, report))) 2711 return failure(); 2712 } 2713 } 2714 2715 if (!getBodyBlock()->mightHaveTerminator()) 2716 return emitOpError() << "expects to have a terminator in the body"; 2717 2718 if (getBodyBlock()->getTerminator()->getOperandTypes() != 2719 getOperation()->getResultTypes()) { 2720 InFlightDiagnostic diag = emitOpError() 2721 << "expects the types of the terminator operands " 2722 "to match the types of the result"; 2723 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 2724 return diag; 2725 } 2726 return success(); 2727 } 2728 2729 void transform::SequenceOp::getEffects( 2730 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2731 getPotentialTopLevelEffects(effects); 2732 } 2733 2734 OperandRange 2735 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2736 assert(point == getBody() && "unexpected region index"); 2737 if (getOperation()->getNumOperands() > 0) 2738 return getOperation()->getOperands(); 2739 return OperandRange(getOperation()->operand_end(), 2740 getOperation()->operand_end()); 2741 } 2742 2743 void transform::SequenceOp::getSuccessorRegions( 2744 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 2745 if (point.isParent()) { 2746 Region *bodyRegion = &getBody(); 2747 regions.emplace_back(bodyRegion, getNumOperands() != 0 2748 ? 
bodyRegion->getArguments() 2749 : Block::BlockArgListType()); 2750 return; 2751 } 2752 2753 assert(point == getBody() && "unexpected region index"); 2754 regions.emplace_back(getOperation()->getResults()); 2755 } 2756 2757 void transform::SequenceOp::getRegionInvocationBounds( 2758 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 2759 (void)operands; 2760 bounds.emplace_back(1, 1); 2761 } 2762 2763 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2764 TypeRange resultTypes, 2765 FailurePropagationMode failurePropagationMode, 2766 Value root, 2767 SequenceBodyBuilderFn bodyBuilder) { 2768 build(builder, state, resultTypes, failurePropagationMode, root, 2769 /*extra_bindings=*/ValueRange()); 2770 Type bbArgType = root.getType(); 2771 buildSequenceBody(builder, state, bbArgType, 2772 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2773 } 2774 2775 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2776 TypeRange resultTypes, 2777 FailurePropagationMode failurePropagationMode, 2778 Value root, ValueRange extraBindings, 2779 SequenceBodyBuilderArgsFn bodyBuilder) { 2780 build(builder, state, resultTypes, failurePropagationMode, root, 2781 extraBindings); 2782 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), 2783 bodyBuilder); 2784 } 2785 2786 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2787 TypeRange resultTypes, 2788 FailurePropagationMode failurePropagationMode, 2789 Type bbArgType, 2790 SequenceBodyBuilderFn bodyBuilder) { 2791 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 2792 /*extra_bindings=*/ValueRange()); 2793 buildSequenceBody(builder, state, bbArgType, 2794 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2795 } 2796 2797 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2798 TypeRange resultTypes, 2799 FailurePropagationMode failurePropagationMode, 2800 Type 
bbArgType, TypeRange extraBindingTypes, 2801 SequenceBodyBuilderArgsFn bodyBuilder) { 2802 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 2803 /*extra_bindings=*/ValueRange()); 2804 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); 2805 } 2806 2807 //===----------------------------------------------------------------------===// 2808 // PrintOp 2809 //===----------------------------------------------------------------------===// 2810 2811 void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2812 StringRef name) { 2813 if (!name.empty()) 2814 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name); 2815 } 2816 2817 void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2818 Value target, StringRef name) { 2819 result.addOperands({target}); 2820 build(builder, result, name); 2821 } 2822 2823 DiagnosedSilenceableFailure 2824 transform::PrintOp::apply(transform::TransformRewriter &rewriter, 2825 transform::TransformResults &results, 2826 transform::TransformState &state) { 2827 llvm::outs() << "[[[ IR printer: "; 2828 if (getName().has_value()) 2829 llvm::outs() << *getName() << " "; 2830 2831 OpPrintingFlags printFlags; 2832 if (getAssumeVerified().value_or(false)) 2833 printFlags.assumeVerified(); 2834 if (getUseLocalScope().value_or(false)) 2835 printFlags.useLocalScope(); 2836 if (getSkipRegions().value_or(false)) 2837 printFlags.skipRegions(); 2838 2839 if (!getTarget()) { 2840 llvm::outs() << "top-level ]]]\n"; 2841 state.getTopLevel()->print(llvm::outs(), printFlags); 2842 llvm::outs() << "\n"; 2843 llvm::outs().flush(); 2844 return DiagnosedSilenceableFailure::success(); 2845 } 2846 2847 llvm::outs() << "]]]\n"; 2848 for (Operation *target : state.getPayloadOps(getTarget())) { 2849 target->print(llvm::outs(), printFlags); 2850 llvm::outs() << "\n"; 2851 } 2852 2853 llvm::outs().flush(); 2854 return DiagnosedSilenceableFailure::success(); 2855 } 
2856 2857 void transform::PrintOp::getEffects( 2858 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2859 // We don't really care about mutability here, but `getTarget` now 2860 // unconditionally casts to a specific type before verification could run 2861 // here. 2862 if (!getTargetMutable().empty()) 2863 onlyReadsHandle(getTargetMutable()[0], effects); 2864 onlyReadsPayload(effects); 2865 2866 // There is no resource for stderr file descriptor, so just declare print 2867 // writes into the default resource. 2868 effects.emplace_back(MemoryEffects::Write::get()); 2869 } 2870 2871 //===----------------------------------------------------------------------===// 2872 // VerifyOp 2873 //===----------------------------------------------------------------------===// 2874 2875 DiagnosedSilenceableFailure 2876 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, 2877 Operation *target, 2878 transform::ApplyToEachResultList &results, 2879 transform::TransformState &state) { 2880 if (failed(::mlir::verify(target))) { 2881 DiagnosedDefiniteFailure diag = emitDefiniteFailure() 2882 << "failed to verify payload op"; 2883 diag.attachNote(target->getLoc()) << "payload op"; 2884 return diag; 2885 } 2886 return DiagnosedSilenceableFailure::success(); 2887 } 2888 2889 void transform::VerifyOp::getEffects( 2890 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2891 transform::onlyReadsHandle(getTargetMutable(), effects); 2892 } 2893 2894 //===----------------------------------------------------------------------===// 2895 // YieldOp 2896 //===----------------------------------------------------------------------===// 2897 2898 void transform::YieldOp::getEffects( 2899 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2900 onlyReadsHandle(getOperandsMutable(), effects); 2901 } 2902