1 //===- TransformOps.cpp - Transform dialect operations --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" 15 #include "mlir/Dialect/Transform/IR/TransformAttrs.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dominance.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/Verifier.h" 23 #include "mlir/Interfaces/ControlFlowInterfaces.h" 24 #include "mlir/Interfaces/FunctionImplementation.h" 25 #include "mlir/Interfaces/FunctionInterfaces.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Pass/PassManager.h" 28 #include "mlir/Pass/PassRegistry.h" 29 #include "mlir/Transforms/CSE.h" 30 #include "mlir/Transforms/DialectConversion.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 33 #include "llvm/ADT/STLExtras.h" 34 #include "llvm/ADT/ScopeExit.h" 35 #include "llvm/ADT/SmallPtrSet.h" 36 #include "llvm/ADT/TypeSwitch.h" 37 #include "llvm/Support/Debug.h" 38 #include <optional> 39 40 #define DEBUG_TYPE "transform-dialect" 41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 42 43 #define DEBUG_TYPE_MATCHER "transform-matcher" 44 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") 45 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) 46 47 using namespace mlir; 48 49 static ParseResult parseSequenceOpOperands( 50 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 51 Type &rootType, 52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 53 SmallVectorImpl<Type> &extraBindingTypes); 54 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 55 Value root, Type rootType, 56 ValueRange extraBindings, 57 TypeRange extraBindingTypes); 58 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 59 ArrayAttr matchers, ArrayAttr actions); 60 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, 61 ArrayAttr &matchers, 62 ArrayAttr &actions); 63 64 /// Helper function to check if the given transform op is contained in (or 65 /// equal to) the given payload target op. In that case, an error is returned. 66 /// Transforming transform IR that is currently executing is generally unsafe. 67 static DiagnosedSilenceableFailure 68 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, 69 Operation *payload) { 70 Operation *transformAncestor = transform.getOperation(); 71 while (transformAncestor) { 72 if (transformAncestor == payload) { 73 DiagnosedDefiniteFailure diag = 74 transform.emitDefiniteFailure() 75 << "cannot apply transform to itself (or one of its ancestors)"; 76 diag.attachNote(payload->getLoc()) << "target payload op"; 77 return diag; 78 } 79 transformAncestor = transformAncestor->getParentOp(); 80 } 81 return DiagnosedSilenceableFailure::success(); 82 } 83 84 #define GET_OP_CLASSES 85 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 86 87 //===----------------------------------------------------------------------===// 88 // AlternativesOp 89 //===----------------------------------------------------------------------===// 90 91 OperandRange 92 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { 93 if (!point.isParent() && getOperation()->getNumOperands() == 1) 94 return getOperation()->getOperands(); 95 return OperandRange(getOperation()->operand_end(), 96 getOperation()->operand_end()); 97 } 98 99 void transform::AlternativesOp::getSuccessorRegions( 100 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 101 for (Region &alternative : llvm::drop_begin( 102 getAlternatives(), 103 point.isParent() ? 0 104 : point.getRegionOrNull()->getRegionNumber() + 1)) { 105 regions.emplace_back(&alternative, !getOperands().empty() 106 ? alternative.getArguments() 107 : Block::BlockArgListType()); 108 } 109 if (!point.isParent()) 110 regions.emplace_back(getOperation()->getResults()); 111 } 112 113 void transform::AlternativesOp::getRegionInvocationBounds( 114 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 115 (void)operands; 116 // The region corresponding to the first alternative is always executed, the 117 // remaining may or may not be executed. 118 bounds.reserve(getNumRegions()); 119 bounds.emplace_back(1, 1); 120 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 121 } 122 123 static void forwardEmptyOperands(Block *block, transform::TransformState &state, 124 transform::TransformResults &results) { 125 for (const auto &res : block->getParentOp()->getOpResults()) 126 results.set(res, {}); 127 } 128 129 DiagnosedSilenceableFailure 130 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, 131 transform::TransformResults &results, 132 transform::TransformState &state) { 133 SmallVector<Operation *> originals; 134 if (Value scopeHandle = getScope()) 135 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 136 else 137 originals.push_back(state.getTopLevel()); 138 139 for (Operation *original : originals) { 140 if (original->isAncestor(getOperation())) { 141 auto diag = emitDefiniteFailure() 142 << "scope must not contain the transforms being applied"; 143 diag.attachNote(original->getLoc()) << "scope"; 144 return diag; 145 } 146 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 147 auto diag = emitDefiniteFailure() 148 << "only isolated-from-above ops can be alternative scopes"; 149 diag.attachNote(original->getLoc()) << "scope"; 150 return diag; 151 } 152 } 153 154 for (Region ® : getAlternatives()) { 155 // Clone the scope operations and make the transforms in this alternative 156 // region apply to them by virtue of mapping the block argument (the only 157 // visible handle) to the cloned scope operations. This effectively prevents 158 // the transformation from accessing any IR outside the scope. 159 auto scope = state.make_region_scope(reg); 160 auto clones = llvm::to_vector( 161 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 162 auto deleteClones = llvm::make_scope_exit([&] { 163 for (Operation *clone : clones) 164 clone->erase(); 165 }); 166 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 167 return DiagnosedSilenceableFailure::definiteFailure(); 168 169 bool failed = false; 170 for (Operation &transform : reg.front().without_terminator()) { 171 DiagnosedSilenceableFailure result = 172 state.applyTransform(cast<TransformOpInterface>(transform)); 173 if (result.isSilenceableFailure()) { 174 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 175 << "\n"); 176 failed = true; 177 break; 178 } 179 180 if (::mlir::failed(result.silence())) 181 return DiagnosedSilenceableFailure::definiteFailure(); 182 } 183 184 // If all operations in the given alternative succeeded, no need to consider 185 // the rest. Replace the original scoping operation with the clone on which 186 // the transformations were performed. 187 if (!failed) { 188 // We will be using the clones, so cancel their scheduled deletion. 189 deleteClones.release(); 190 TrackingListener listener(state, *this); 191 IRRewriter rewriter(getContext(), &listener); 192 for (const auto &kvp : llvm::zip(originals, clones)) { 193 Operation *original = std::get<0>(kvp); 194 Operation *clone = std::get<1>(kvp); 195 original->getBlock()->getOperations().insert(original->getIterator(), 196 clone); 197 rewriter.replaceOp(original, clone->getResults()); 198 } 199 detail::forwardTerminatorOperands(®.front(), state, results); 200 return DiagnosedSilenceableFailure::success(); 201 } 202 } 203 return emitSilenceableError() << "all alternatives failed"; 204 } 205 206 void transform::AlternativesOp::getEffects( 207 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 208 consumesHandle(getOperands(), effects); 209 producesHandle(getResults(), effects); 210 for (Region *region : getRegions()) { 211 if (!region->empty()) 212 producesHandle(region->front().getArguments(), effects); 213 } 214 modifiesPayload(effects); 215 } 216 217 LogicalResult transform::AlternativesOp::verify() { 218 for (Region &alternative : getAlternatives()) { 219 Block &block = alternative.front(); 220 Operation *terminator = block.getTerminator(); 221 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 222 InFlightDiagnostic diag = emitOpError() 223 << "expects terminator operands to have the " 224 "same type as results of the operation"; 225 diag.attachNote(terminator->getLoc()) << "terminator"; 226 return diag; 227 } 228 } 229 230 return success(); 231 } 232 233 //===----------------------------------------------------------------------===// 234 // AnnotateOp 235 //===----------------------------------------------------------------------===// 236 237 DiagnosedSilenceableFailure 238 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, 239 transform::TransformResults &results, 240 transform::TransformState &state) { 241 SmallVector<Operation *> targets = 242 llvm::to_vector(state.getPayloadOps(getTarget())); 243 244 Attribute attr = UnitAttr::get(getContext()); 245 if (auto paramH = getParam()) { 246 ArrayRef<Attribute> params = state.getParams(paramH); 247 if (params.size() != 1) { 248 if (targets.size() != params.size()) { 249 return emitSilenceableError() 250 << "parameter and target have different payload lengths (" 251 << params.size() << " vs " << targets.size() << ")"; 252 } 253 for (auto &&[target, attr] : llvm::zip_equal(targets, params)) 254 target->setAttr(getName(), attr); 255 return DiagnosedSilenceableFailure::success(); 256 } 257 attr = params[0]; 258 } 259 for (auto target : targets) 260 target->setAttr(getName(), attr); 261 return DiagnosedSilenceableFailure::success(); 262 } 263 264 void transform::AnnotateOp::getEffects( 265 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 266 onlyReadsHandle(getTarget(), effects); 267 onlyReadsHandle(getParam(), effects); 268 modifiesPayload(effects); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // ApplyCommonSubexpressionEliminationOp 273 //===----------------------------------------------------------------------===// 274 275 DiagnosedSilenceableFailure 276 transform::ApplyCommonSubexpressionEliminationOp::applyToOne( 277 transform::TransformRewriter &rewriter, Operation *target, 278 ApplyToEachResultList &results, transform::TransformState &state) { 279 // Make sure that this transform is not applied to itself. Modifying the 280 // transform IR while it is being interpreted is generally dangerous. 281 DiagnosedSilenceableFailure payloadCheck = 282 ensurePayloadIsSeparateFromTransform(*this, target); 283 if (!payloadCheck.succeeded()) 284 return payloadCheck; 285 286 DominanceInfo domInfo; 287 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); 288 return DiagnosedSilenceableFailure::success(); 289 } 290 291 void transform::ApplyCommonSubexpressionEliminationOp::getEffects( 292 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 293 transform::onlyReadsHandle(getTarget(), effects); 294 transform::modifiesPayload(effects); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // ApplyDeadCodeEliminationOp 299 //===----------------------------------------------------------------------===// 300 301 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( 302 transform::TransformRewriter &rewriter, Operation *target, 303 ApplyToEachResultList &results, transform::TransformState &state) { 304 // Make sure that this transform is not applied to itself. Modifying the 305 // transform IR while it is being interpreted is generally dangerous. 306 DiagnosedSilenceableFailure payloadCheck = 307 ensurePayloadIsSeparateFromTransform(*this, target); 308 if (!payloadCheck.succeeded()) 309 return payloadCheck; 310 311 // Maintain a worklist of potentially dead ops. 312 SetVector<Operation *> worklist; 313 314 // Helper function that adds all defining ops of used values (operands and 315 // operands of nested ops). 316 auto addDefiningOpsToWorklist = [&](Operation *op) { 317 op->walk([&](Operation *op) { 318 for (Value v : op->getOperands()) 319 if (Operation *defOp = v.getDefiningOp()) 320 if (target->isProperAncestor(defOp)) 321 worklist.insert(defOp); 322 }); 323 }; 324 325 // Helper function that erases an op. 326 auto eraseOp = [&](Operation *op) { 327 // Remove op and nested ops from the worklist. 328 op->walk([&](Operation *op) { 329 auto it = llvm::find(worklist, op); 330 if (it != worklist.end()) 331 worklist.erase(it); 332 }); 333 rewriter.eraseOp(op); 334 }; 335 336 // Initial walk over the IR. 337 target->walk<WalkOrder::PostOrder>([&](Operation *op) { 338 if (op != target && isOpTriviallyDead(op)) { 339 addDefiningOpsToWorklist(op); 340 eraseOp(op); 341 } 342 }); 343 344 // Erase all ops that have become dead. 345 while (!worklist.empty()) { 346 Operation *op = worklist.pop_back_val(); 347 if (!isOpTriviallyDead(op)) 348 continue; 349 addDefiningOpsToWorklist(op); 350 eraseOp(op); 351 } 352 353 return DiagnosedSilenceableFailure::success(); 354 } 355 356 void transform::ApplyDeadCodeEliminationOp::getEffects( 357 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 358 transform::onlyReadsHandle(getTarget(), effects); 359 transform::modifiesPayload(effects); 360 } 361 362 //===----------------------------------------------------------------------===// 363 // ApplyPatternsOp 364 //===----------------------------------------------------------------------===// 365 366 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( 367 transform::TransformRewriter &rewriter, Operation *target, 368 ApplyToEachResultList &results, transform::TransformState &state) { 369 // Make sure that this transform is not applied to itself. Modifying the 370 // transform IR while it is being interpreted is generally dangerous. Even 371 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver 372 // performs many additional simplifications such as dead code elimination. 373 DiagnosedSilenceableFailure payloadCheck = 374 ensurePayloadIsSeparateFromTransform(*this, target); 375 if (!payloadCheck.succeeded()) 376 return payloadCheck; 377 378 // Gather all specified patterns. 379 MLIRContext *ctx = target->getContext(); 380 RewritePatternSet patterns(ctx); 381 if (!getRegion().empty()) { 382 for (Operation &op : getRegion().front()) { 383 cast<transform::PatternDescriptorOpInterface>(&op) 384 .populatePatternsWithState(patterns, state); 385 } 386 } 387 388 // Configure the GreedyPatternRewriteDriver. 389 GreedyRewriteConfig config; 390 config.listener = 391 static_cast<RewriterBase::Listener *>(rewriter.getListener()); 392 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 393 394 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE 395 // was requested, apply the greedy pattern rewrite only once. (The greedy 396 // pattern rewrite driver already iterates to a fixpoint internally.) 397 bool cseChanged = false; 398 // One or two iterations should be sufficient. Stop iterating after a certain 399 // threshold to make debugging easier. 400 static const int64_t kNumMaxIterations = 50; 401 int64_t iteration = 0; 402 do { 403 LogicalResult result = failure(); 404 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 405 // Op is isolated from above. Apply patterns and also perform region 406 // simplification. 407 result = applyPatternsAndFoldGreedily(target, frozenPatterns, config); 408 } else { 409 // Manually gather list of ops because the other 410 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated 411 // from above. This way, patterns can be applied to ops that are not 412 // isolated from above. Regions are not being simplified. Furthermore, 413 // only a single greedy rewrite iteration is performed. 414 SmallVector<Operation *> ops; 415 target->walk([&](Operation *nestedOp) { 416 if (target != nestedOp) 417 ops.push_back(nestedOp); 418 }); 419 result = applyOpPatternsAndFold(ops, frozenPatterns, config); 420 } 421 422 // A failure typically indicates that the pattern application did not 423 // converge. 424 if (failed(result)) { 425 return emitSilenceableFailure(target) 426 << "greedy pattern application failed"; 427 } 428 429 if (getApplyCse()) { 430 DominanceInfo domInfo; 431 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, 432 &cseChanged); 433 } 434 } while (cseChanged && ++iteration < kNumMaxIterations); 435 436 if (iteration == kNumMaxIterations) 437 return emitDefiniteFailure() << "fixpoint iteration did not converge"; 438 439 return DiagnosedSilenceableFailure::success(); 440 } 441 442 LogicalResult transform::ApplyPatternsOp::verify() { 443 if (!getRegion().empty()) { 444 for (Operation &op : getRegion().front()) { 445 if (!isa<transform::PatternDescriptorOpInterface>(&op)) { 446 InFlightDiagnostic diag = emitOpError() 447 << "expected children ops to implement " 448 "PatternDescriptorOpInterface"; 449 diag.attachNote(op.getLoc()) << "op without interface"; 450 return diag; 451 } 452 } 453 } 454 return success(); 455 } 456 457 void transform::ApplyPatternsOp::getEffects( 458 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 459 transform::onlyReadsHandle(getTarget(), effects); 460 transform::modifiesPayload(effects); 461 } 462 463 void transform::ApplyPatternsOp::build( 464 OpBuilder &builder, OperationState &result, Value target, 465 function_ref<void(OpBuilder &, Location)> bodyBuilder) { 466 result.addOperands(target); 467 468 OpBuilder::InsertionGuard g(builder); 469 Region *region = result.addRegion(); 470 builder.createBlock(region); 471 if (bodyBuilder) 472 bodyBuilder(builder, result.location); 473 } 474 475 //===----------------------------------------------------------------------===// 476 // ApplyCanonicalizationPatternsOp 477 //===----------------------------------------------------------------------===// 478 479 void transform::ApplyCanonicalizationPatternsOp::populatePatterns( 480 RewritePatternSet &patterns) { 481 MLIRContext *ctx = patterns.getContext(); 482 for (Dialect *dialect : ctx->getLoadedDialects()) 483 dialect->getCanonicalizationPatterns(patterns); 484 for (RegisteredOperationName op : ctx->getRegisteredOperations()) 485 op.getCanonicalizationPatterns(patterns, ctx); 486 } 487 488 //===----------------------------------------------------------------------===// 489 // ApplyConversionPatternsOp 490 //===----------------------------------------------------------------------===// 491 492 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( 493 transform::TransformRewriter &rewriter, 494 transform::TransformResults &results, transform::TransformState &state) { 495 MLIRContext *ctx = getContext(); 496 497 // Instantiate the default type converter if a type converter builder is 498 // specified. 499 std::unique_ptr<TypeConverter> defaultTypeConverter; 500 transform::TypeConverterBuilderOpInterface typeConverterBuilder = 501 getDefaultTypeConverter(); 502 if (typeConverterBuilder) 503 defaultTypeConverter = typeConverterBuilder.getTypeConverter(); 504 505 // Configure conversion target. 506 ConversionTarget conversionTarget(*getContext()); 507 if (getLegalOps()) 508 for (Attribute attr : cast<ArrayAttr>(*getLegalOps())) 509 conversionTarget.addLegalOp( 510 OperationName(cast<StringAttr>(attr).getValue(), ctx)); 511 if (getIllegalOps()) 512 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps())) 513 conversionTarget.addIllegalOp( 514 OperationName(cast<StringAttr>(attr).getValue(), ctx)); 515 if (getLegalDialects()) 516 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects())) 517 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue()); 518 if (getIllegalDialects()) 519 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects())) 520 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue()); 521 522 // Gather all specified patterns. 523 RewritePatternSet patterns(ctx); 524 // Need to keep the converters alive until after pattern application because 525 // the patterns take a reference to an object that would otherwise get out of 526 // scope. 527 SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters; 528 if (!getPatterns().empty()) { 529 for (Operation &op : getPatterns().front()) { 530 auto descriptor = 531 cast<transform::ConversionPatternDescriptorOpInterface>(&op); 532 533 // Check if this pattern set specifies a type converter. 534 std::unique_ptr<TypeConverter> typeConverter = 535 descriptor.getTypeConverter(); 536 TypeConverter *converter = nullptr; 537 if (typeConverter) { 538 keepAliveConverters.emplace_back(std::move(typeConverter)); 539 converter = keepAliveConverters.back().get(); 540 } else { 541 // No type converter specified: Use the default type converter. 542 if (!defaultTypeConverter) { 543 auto diag = emitDefiniteFailure() 544 << "pattern descriptor does not specify type " 545 "converter and apply_conversion_patterns op has " 546 "no default type converter"; 547 diag.attachNote(op.getLoc()) << "pattern descriptor op"; 548 return diag; 549 } 550 converter = defaultTypeConverter.get(); 551 } 552 553 // Add descriptor-specific updates to the conversion target, which may 554 // depend on the final type converter. In structural converters, the 555 // legality of types dictates the dynamic legality of an operation. 556 descriptor.populateConversionTargetRules(*converter, conversionTarget); 557 558 descriptor.populatePatterns(*converter, patterns); 559 } 560 } 561 562 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 563 for (Operation *target : state.getPayloadOps(getTarget())) { 564 // Make sure that this transform is not applied to itself. Modifying the 565 // transform IR while it is being interpreted is generally dangerous. 566 DiagnosedSilenceableFailure payloadCheck = 567 ensurePayloadIsSeparateFromTransform(*this, target); 568 if (!payloadCheck.succeeded()) 569 return payloadCheck; 570 571 LogicalResult status = failure(); 572 if (getPartialConversion()) { 573 status = applyPartialConversion(target, conversionTarget, frozenPatterns); 574 } else { 575 status = applyFullConversion(target, conversionTarget, frozenPatterns); 576 } 577 578 if (failed(status)) { 579 auto diag = emitSilenceableError() << "dialect conversion failed"; 580 diag.attachNote(target->getLoc()) << "target op"; 581 return diag; 582 } 583 } 584 585 return DiagnosedSilenceableFailure::success(); 586 } 587 588 LogicalResult transform::ApplyConversionPatternsOp::verify() { 589 if (getNumRegions() != 1 && getNumRegions() != 2) 590 return emitOpError() << "expected 1 or 2 regions"; 591 if (!getPatterns().empty()) { 592 for (Operation &op : getPatterns().front()) { 593 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) { 594 InFlightDiagnostic diag = 595 emitOpError() << "expected pattern children ops to implement " 596 "ConversionPatternDescriptorOpInterface"; 597 diag.attachNote(op.getLoc()) << "op without interface"; 598 return diag; 599 } 600 } 601 } 602 if (getNumRegions() == 2) { 603 Region &typeConverterRegion = getRegion(1); 604 if (!llvm::hasSingleElement(typeConverterRegion.front())) 605 return emitOpError() 606 << "expected exactly one op in default type converter region"; 607 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>( 608 &typeConverterRegion.front().front()); 609 if (!typeConverterOp) { 610 InFlightDiagnostic diag = emitOpError() 611 << "expected default converter child op to " 612 "implement TypeConverterBuilderOpInterface"; 613 diag.attachNote(typeConverterOp->getLoc()) << "op without interface"; 614 return diag; 615 } 616 // Check default type converter type. 617 if (!getPatterns().empty()) { 618 for (Operation &op : getPatterns().front()) { 619 auto descriptor = 620 cast<transform::ConversionPatternDescriptorOpInterface>(&op); 621 if (failed(descriptor.verifyTypeConverter(typeConverterOp))) 622 return failure(); 623 } 624 } 625 } 626 return success(); 627 } 628 629 void transform::ApplyConversionPatternsOp::getEffects( 630 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 631 transform::consumesHandle(getTarget(), effects); 632 transform::modifiesPayload(effects); 633 } 634 635 void transform::ApplyConversionPatternsOp::build( 636 OpBuilder &builder, OperationState &result, Value target, 637 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder, 638 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) { 639 result.addOperands(target); 640 641 { 642 OpBuilder::InsertionGuard g(builder); 643 Region *region1 = result.addRegion(); 644 builder.createBlock(region1); 645 if (patternsBodyBuilder) 646 patternsBodyBuilder(builder, result.location); 647 } 648 { 649 OpBuilder::InsertionGuard g(builder); 650 Region *region2 = result.addRegion(); 651 builder.createBlock(region2); 652 if (typeConverterBodyBuilder) 653 typeConverterBodyBuilder(builder, result.location); 654 } 655 } 656 657 //===----------------------------------------------------------------------===// 658 // ApplyToLLVMConversionPatternsOp 659 //===----------------------------------------------------------------------===// 660 661 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( 662 TypeConverter &typeConverter, RewritePatternSet &patterns) { 663 Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 664 assert(dialect && "expected that dialect is loaded"); 665 auto iface = cast<ConvertToLLVMPatternInterface>(dialect); 666 // ConversionTarget is currently ignored because the enclosing 667 // apply_conversion_patterns op sets up its own ConversionTarget. 668 ConversionTarget target(*getContext()); 669 iface->populateConvertToLLVMConversionPatterns( 670 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns); 671 } 672 673 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( 674 transform::TypeConverterBuilderOpInterface builder) { 675 if (builder.getTypeConverterType() != "LLVMTypeConverter") 676 return emitOpError("expected LLVMTypeConverter"); 677 return success(); 678 } 679 680 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { 681 Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 682 if (!dialect) 683 return emitOpError("unknown dialect or dialect not loaded: ") 684 << getDialectName(); 685 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 686 if (!iface) 687 return emitOpError( 688 "dialect does not implement ConvertToLLVMPatternInterface or " 689 "extension was not loaded: ") 690 << getDialectName(); 691 return success(); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // ApplyLoopInvariantCodeMotionOp 696 //===----------------------------------------------------------------------===// 697 698 DiagnosedSilenceableFailure 699 transform::ApplyLoopInvariantCodeMotionOp::applyToOne( 700 transform::TransformRewriter &rewriter, LoopLikeOpInterface target, 701 transform::ApplyToEachResultList &results, 702 transform::TransformState &state) { 703 // Currently, LICM does not remove operations, so we don't need tracking. 704 // If this ever changes, add a LICM entry point that takes a rewriter. 705 moveLoopInvariantCode(target); 706 return DiagnosedSilenceableFailure::success(); 707 } 708 709 void transform::ApplyLoopInvariantCodeMotionOp::getEffects( 710 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 711 transform::onlyReadsHandle(getTarget(), effects); 712 transform::modifiesPayload(effects); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // ApplyRegisteredPassOp 717 //===----------------------------------------------------------------------===// 718 719 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( 720 transform::TransformRewriter &rewriter, Operation *target, 721 ApplyToEachResultList &results, transform::TransformState &state) { 722 // Make sure that this transform is not applied to itself. Modifying the 723 // transform IR while it is being interpreted is generally dangerous. Even 724 // more so when applying passes because they may perform a wide range of IR 725 // modifications. 726 DiagnosedSilenceableFailure payloadCheck = 727 ensurePayloadIsSeparateFromTransform(*this, target); 728 if (!payloadCheck.succeeded()) 729 return payloadCheck; 730 731 // Get pass or pass pipeline from registry. 732 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); 733 if (!info) 734 info = PassInfo::lookup(getPassName()); 735 if (!info) 736 return emitDefiniteFailure() 737 << "unknown pass or pass pipeline: " << getPassName(); 738 739 // Create pass manager and run the pass or pass pipeline. 740 PassManager pm(getContext()); 741 if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { 742 emitError(msg); 743 return failure(); 744 }))) { 745 return emitDefiniteFailure() 746 << "failed to add pass or pass pipeline to pipeline: " 747 << getPassName(); 748 } 749 if (failed(pm.run(target))) { 750 auto diag = emitSilenceableError() << "pass pipeline failed"; 751 diag.attachNote(target->getLoc()) << "target op"; 752 return diag; 753 } 754 755 results.push_back(target); 756 return DiagnosedSilenceableFailure::success(); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // CastOp 761 //===----------------------------------------------------------------------===// 762 763 DiagnosedSilenceableFailure 764 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, 765 Operation *target, ApplyToEachResultList &results, 766 transform::TransformState &state) { 767 results.push_back(target); 768 return DiagnosedSilenceableFailure::success(); 769 } 770 771 void transform::CastOp::getEffects( 772 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 773 onlyReadsPayload(effects); 774 onlyReadsHandle(getInput(), effects); 775 producesHandle(getOutput(), effects); 776 } 777 778 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 779 assert(inputs.size() == 1 && "expected one input"); 780 assert(outputs.size() == 1 && "expected one output"); 781 return llvm::all_of( 782 std::initializer_list<Type>{inputs.front(), outputs.front()}, 783 [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); }); 784 } 785 786 //===----------------------------------------------------------------------===// 787 // CollectMatchingOp 788 //===----------------------------------------------------------------------===// 789 790 /// Applies matcher operations from the given `block` assigning `op` as the 791 /// payload of the block's first argument. Updates `state` accordingly. If any 792 /// of the matcher produces a silenceable failure, discards it (printing the 793 /// content to the debug output stream) and returns failure. If any of the 794 /// matchers produces a definite failure, reports it and returns failure. If all 795 /// matchers in the block succeed, populates `mappings` with the payload 796 /// entities associated with the block terminator operands. 797 static DiagnosedSilenceableFailure 798 matchBlock(Block &block, Operation *op, transform::TransformState &state, 799 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { 800 assert(block.getParent() && "cannot match using a detached block"); 801 auto matchScope = state.make_region_scope(*block.getParent()); 802 if (failed(state.mapBlockArgument(block.getArgument(0), {op}))) 803 return DiagnosedSilenceableFailure::definiteFailure(); 804 805 for (Operation &match : block.without_terminator()) { 806 if (!isa<transform::MatchOpInterface>(match)) { 807 return emitDefiniteFailure(match.getLoc()) 808 << "expected operations in the match part to " 809 "implement MatchOpInterface"; 810 } 811 DiagnosedSilenceableFailure diag = 812 state.applyTransform(cast<transform::TransformOpInterface>(match)); 813 if (diag.succeeded()) 814 continue; 815 816 return diag; 817 } 818 819 // Remember the values mapped to the terminator operands so we can 820 // forward them to the action. 821 ValueRange yieldedValues = block.getTerminator()->getOperands(); 822 transform::detail::prepareValueMappings(mappings, yieldedValues, state); 823 return DiagnosedSilenceableFailure::success(); 824 } 825 826 /// Returns `true` if both types implement one of the interfaces provided as 827 /// template parameters. 828 template <typename... Tys> 829 static bool implementSameInterface(Type t1, Type t2) { 830 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); 831 } 832 833 /// Returns `true` if both types implement one of the transform dialect 834 /// interfaces. 835 static bool implementSameTransformInterface(Type t1, Type t2) { 836 return implementSameInterface<transform::TransformHandleTypeInterface, 837 transform::TransformParamTypeInterface, 838 transform::TransformValueHandleTypeInterface>( 839 t1, t2); 840 } 841 842 //===----------------------------------------------------------------------===// 843 // CollectMatchingOp 844 //===----------------------------------------------------------------------===// 845 846 DiagnosedSilenceableFailure 847 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, 848 transform::TransformResults &results, 849 transform::TransformState &state) { 850 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( 851 getOperation(), getMatcher()); 852 if (matcher.isExternal()) { 853 return emitDefiniteFailure() 854 << "unresolved external symbol " << getMatcher(); 855 } 856 857 SmallVector<SmallVector<MappedValue>, 2> rawResults; 858 rawResults.resize(getOperation()->getNumResults()); 859 std::optional<DiagnosedSilenceableFailure> maybeFailure; 860 for (Operation *root : state.getPayloadOps(getRoot())) { 861 WalkResult walkResult = root->walk([&](Operation *op) { 862 DEBUG_MATCHER({ 863 DBGS_MATCHER() << "matching "; 864 op->print(llvm::dbgs(), 865 OpPrintingFlags().assumeVerified().skipRegions()); 866 llvm::dbgs() << " @" << op << "\n"; 867 }); 868 869 // Try matching. 870 SmallVector<SmallVector<MappedValue>> mappings; 871 DiagnosedSilenceableFailure diag = 872 matchBlock(matcher.getFunctionBody().front(), op, state, mappings); 873 if (diag.isDefiniteFailure()) 874 return WalkResult::interrupt(); 875 if (diag.isSilenceableFailure()) { 876 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() 877 << " failed: " << diag.getMessage()); 878 return WalkResult::advance(); 879 } 880 881 // If succeeded, collect results. 882 for (auto &&[i, mapping] : llvm::enumerate(mappings)) { 883 if (mapping.size() != 1) { 884 maybeFailure.emplace(emitSilenceableError() 885 << "result #" << i << ", associated with " 886 << mapping.size() 887 << " payload objects, expected 1"); 888 return WalkResult::interrupt(); 889 } 890 rawResults[i].push_back(mapping[0]); 891 } 892 return WalkResult::advance(); 893 }); 894 if (walkResult.wasInterrupted()) 895 return std::move(*maybeFailure); 896 assert(!maybeFailure && "failure set but the walk was not interrupted"); 897 898 for (auto &&[opResult, rawResult] : 899 llvm::zip_equal(getOperation()->getResults(), rawResults)) { 900 results.setMappedValues(opResult, rawResult); 901 } 902 } 903 return DiagnosedSilenceableFailure::success(); 904 } 905 906 void transform::CollectMatchingOp::getEffects( 907 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 908 onlyReadsHandle(getRoot(), effects); 909 producesHandle(getResults(), effects); 910 onlyReadsPayload(effects); 911 } 912 913 LogicalResult transform::CollectMatchingOp::verifySymbolUses( 914 SymbolTableCollection &symbolTable) { 915 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 916 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); 917 if (!matcherSymbol || 918 !isa<TransformOpInterface>(matcherSymbol.getOperation())) 919 return emitError() << "unresolved matcher symbol " << getMatcher(); 920 921 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); 922 if (argumentTypes.size() != 1 || 923 !isa<TransformHandleTypeInterface>(argumentTypes[0])) { 924 return emitError() 925 << "expected the matcher to take one operation handle argument"; 926 } 927 if (!matcherSymbol.getArgAttr( 928 0, transform::TransformDialect::kArgReadOnlyAttrName)) { 929 return emitError() << "expected the matcher argument to be marked readonly"; 930 } 931 932 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); 933 if (resultTypes.size() != getOperation()->getNumResults()) { 934 return emitError() 935 << "expected the matcher to yield as many values as op has results (" 936 << getOperation()->getNumResults() << "), got " 937 << resultTypes.size(); 938 } 939 940 for (auto &&[i, matcherType, resultType] : 941 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { 942 if (implementSameTransformInterface(matcherType, resultType)) 943 continue; 944 945 return emitError() 946 << "mismatching type interfaces for matcher result and op result #" 947 << i; 948 } 949 950 return success(); 951 } 952 953 //===----------------------------------------------------------------------===// 954 // ForeachMatchOp 955 //===----------------------------------------------------------------------===// 956 957 DiagnosedSilenceableFailure 958 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, 959 transform::TransformResults &results, 960 transform::TransformState &state) { 961 SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>> 962 matchActionPairs; 963 matchActionPairs.reserve(getMatchers().size()); 964 SymbolTableCollection symbolTable; 965 for (auto &&[matcher, action] : 966 llvm::zip_equal(getMatchers(), getActions())) { 967 auto matcherSymbol = 968 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( 969 getOperation(), cast<SymbolRefAttr>(matcher)); 970 auto actionSymbol = 971 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( 972 getOperation(), cast<SymbolRefAttr>(action)); 973 assert(matcherSymbol && actionSymbol && 974 "unresolved symbols not caught by the verifier"); 975 976 if (matcherSymbol.isExternal()) 977 return emitDefiniteFailure() << "unresolved external symbol " << matcher; 978 if (actionSymbol.isExternal()) 979 return emitDefiniteFailure() << "unresolved external symbol " << action; 980 981 matchActionPairs.emplace_back(matcherSymbol, actionSymbol); 982 } 983 984 for (Operation *root : state.getPayloadOps(getRoot())) { 985 WalkResult walkResult = root->walk([&](Operation *op) { 986 // If getRestrictRoot is not present, skip over the root op itself so we 987 // don't invalidate it. 988 if (!getRestrictRoot() && op == root) 989 return WalkResult::advance(); 990 991 DEBUG_MATCHER({ 992 DBGS_MATCHER() << "matching "; 993 op->print(llvm::dbgs(), 994 OpPrintingFlags().assumeVerified().skipRegions()); 995 llvm::dbgs() << " @" << op << "\n"; 996 }); 997 998 // Try all the match/action pairs until the first successful match. 999 for (auto [matcher, action] : matchActionPairs) { 1000 SmallVector<SmallVector<MappedValue>> mappings; 1001 DiagnosedSilenceableFailure diag = 1002 matchBlock(matcher.getFunctionBody().front(), op, state, mappings); 1003 if (diag.isDefiniteFailure()) 1004 return WalkResult::interrupt(); 1005 if (diag.isSilenceableFailure()) { 1006 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() 1007 << " failed: " << diag.getMessage()); 1008 continue; 1009 } 1010 1011 auto scope = state.make_region_scope(action.getFunctionBody()); 1012 for (auto &&[arg, map] : llvm::zip_equal( 1013 action.getFunctionBody().front().getArguments(), mappings)) { 1014 if (failed(state.mapBlockArgument(arg, map))) 1015 return WalkResult::interrupt(); 1016 } 1017 1018 for (Operation &transform : 1019 action.getFunctionBody().front().without_terminator()) { 1020 DiagnosedSilenceableFailure result = 1021 state.applyTransform(cast<TransformOpInterface>(transform)); 1022 if (failed(result.checkAndReport())) 1023 return WalkResult::interrupt(); 1024 } 1025 break; 1026 } 1027 return WalkResult::advance(); 1028 }); 1029 if (walkResult.wasInterrupted()) 1030 return DiagnosedSilenceableFailure::definiteFailure(); 1031 } 1032 1033 // The root operation should not have been affected, so we can just reassign 1034 // the payload to the result. Note that we need to consume the root handle to 1035 // make sure any handles to operations inside, that could have been affected 1036 // by actions, are invalidated. 1037 results.set(llvm::cast<OpResult>(getUpdated()), 1038 state.getPayloadOps(getRoot())); 1039 return DiagnosedSilenceableFailure::success(); 1040 } 1041 1042 void transform::ForeachMatchOp::getEffects( 1043 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1044 // Bail if invalid. 1045 if (getOperation()->getNumOperands() < 1 || 1046 getOperation()->getNumResults() < 1) { 1047 return modifiesPayload(effects); 1048 } 1049 1050 consumesHandle(getRoot(), effects); 1051 producesHandle(getUpdated(), effects); 1052 modifiesPayload(effects); 1053 } 1054 1055 /// Parses the comma-separated list of symbol reference pairs of the format 1056 /// `@matcher -> @action`. 1057 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, 1058 ArrayAttr &matchers, 1059 ArrayAttr &actions) { 1060 StringAttr matcher; 1061 StringAttr action; 1062 SmallVector<Attribute> matcherList; 1063 SmallVector<Attribute> actionList; 1064 do { 1065 if (parser.parseSymbolName(matcher) || parser.parseArrow() || 1066 parser.parseSymbolName(action)) { 1067 return failure(); 1068 } 1069 matcherList.push_back(SymbolRefAttr::get(matcher)); 1070 actionList.push_back(SymbolRefAttr::get(action)); 1071 } while (parser.parseOptionalComma().succeeded()); 1072 1073 matchers = parser.getBuilder().getArrayAttr(matcherList); 1074 actions = parser.getBuilder().getArrayAttr(actionList); 1075 return success(); 1076 } 1077 1078 /// Prints the comma-separated list of symbol reference pairs of the format 1079 /// `@matcher -> @action`. 1080 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 1081 ArrayAttr matchers, ArrayAttr actions) { 1082 printer.increaseIndent(); 1083 printer.increaseIndent(); 1084 for (auto &&[matcher, action, idx] : llvm::zip_equal( 1085 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) { 1086 printer.printNewline(); 1087 printer << cast<SymbolRefAttr>(matcher) << " -> " 1088 << cast<SymbolRefAttr>(action); 1089 if (idx != matchers.size() - 1) 1090 printer << ", "; 1091 } 1092 printer.decreaseIndent(); 1093 printer.decreaseIndent(); 1094 } 1095 1096 LogicalResult transform::ForeachMatchOp::verify() { 1097 if (getMatchers().size() != getActions().size()) 1098 return emitOpError() << "expected the same number of matchers and actions"; 1099 if (getMatchers().empty()) 1100 return emitOpError() << "expected at least one match/action pair"; 1101 1102 llvm::SmallPtrSet<Attribute, 8> matcherNames; 1103 for (Attribute name : getMatchers()) { 1104 if (matcherNames.insert(name).second) 1105 continue; 1106 emitWarning() << "matcher " << name 1107 << " is used more than once, only the first match will apply"; 1108 } 1109 1110 return success(); 1111 } 1112 1113 /// Checks that the attributes of the function-like operation have correct 1114 /// consumption effect annotations. If `alsoVerifyInternal`, checks for 1115 /// annotations being present even if they can be inferred from the body. 1116 static DiagnosedSilenceableFailure 1117 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, 1118 bool alsoVerifyInternal = false) { 1119 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation()); 1120 llvm::SmallDenseSet<unsigned> consumedArguments; 1121 if (!op.isExternal()) { 1122 transform::getConsumedBlockArguments(op.getFunctionBody().front(), 1123 consumedArguments); 1124 } 1125 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { 1126 bool isConsumed = 1127 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != 1128 nullptr; 1129 bool isReadOnly = 1130 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != 1131 nullptr; 1132 if (isConsumed && isReadOnly) { 1133 return transformOp.emitSilenceableError() 1134 << "argument #" << i << " cannot be both readonly and consumed"; 1135 } 1136 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { 1137 return transformOp.emitSilenceableError() 1138 << "must provide consumed/readonly status for arguments of " 1139 "external or called ops"; 1140 } 1141 if (op.isExternal()) 1142 continue; 1143 1144 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { 1145 return transformOp.emitSilenceableError() 1146 << "argument #" << i 1147 << " is consumed in the body but is not marked as such"; 1148 } 1149 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) { 1150 // Cannot use op.emitWarning() here as it would attempt to verify the op 1151 // before printing, resulting in infinite recursion. 1152 emitWarning(op->getLoc()) 1153 << "op argument #" << i 1154 << " is not consumed in the body but is marked as consumed"; 1155 } 1156 } 1157 return DiagnosedSilenceableFailure::success(); 1158 } 1159 1160 LogicalResult transform::ForeachMatchOp::verifySymbolUses( 1161 SymbolTableCollection &symbolTable) { 1162 assert(getMatchers().size() == getActions().size()); 1163 auto consumedAttr = 1164 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); 1165 for (auto &&[matcher, action] : 1166 llvm::zip_equal(getMatchers(), getActions())) { 1167 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 1168 symbolTable.lookupNearestSymbolFrom(getOperation(), 1169 cast<SymbolRefAttr>(matcher))); 1170 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>( 1171 symbolTable.lookupNearestSymbolFrom(getOperation(), 1172 cast<SymbolRefAttr>(action))); 1173 if (!matcherSymbol || 1174 !isa<TransformOpInterface>(matcherSymbol.getOperation())) 1175 return emitError() << "unresolved matcher symbol " << matcher; 1176 if (!actionSymbol || 1177 !isa<TransformOpInterface>(actionSymbol.getOperation())) 1178 return emitError() << "unresolved action symbol " << action; 1179 1180 if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, 1181 /*emitWarnings=*/false, 1182 /*alsoVerifyInternal=*/true) 1183 .checkAndReport())) { 1184 return failure(); 1185 } 1186 if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, 1187 /*emitWarnings=*/false, 1188 /*alsoVerifyInternal=*/true) 1189 .checkAndReport())) { 1190 return failure(); 1191 } 1192 1193 ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes(); 1194 ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes(); 1195 if (matcherResults.size() != actionArguments.size()) { 1196 return emitError() << "mismatching number of matcher results and " 1197 "action arguments between " 1198 << matcher << " (" << matcherResults.size() << ") and " 1199 << action << " (" << actionArguments.size() << ")"; 1200 } 1201 for (auto &&[i, matcherType, actionType] : 1202 llvm::enumerate(matcherResults, actionArguments)) { 1203 if (implementSameTransformInterface(matcherType, actionType)) 1204 continue; 1205 1206 return emitError() << "mismatching type interfaces for matcher result " 1207 "and action argument #" 1208 << i; 1209 } 1210 1211 if (!actionSymbol.getResultTypes().empty()) { 1212 InFlightDiagnostic diag = 1213 emitError() << "action symbol is not expected to have results"; 1214 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; 1215 return diag; 1216 } 1217 1218 if (matcherSymbol.getArgumentTypes().size() != 1 || 1219 !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0], 1220 getRoot().getType())) { 1221 InFlightDiagnostic diag = 1222 emitOpError() << "expects matcher symbol to have one argument with " 1223 "the same transform interface as the first operand"; 1224 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1225 return diag; 1226 } 1227 1228 if (matcherSymbol.getArgAttr(0, consumedAttr)) { 1229 InFlightDiagnostic diag = 1230 emitOpError() 1231 << "does not expect matcher symbol to consume its operand"; 1232 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1233 return diag; 1234 } 1235 } 1236 return success(); 1237 } 1238 1239 //===----------------------------------------------------------------------===// 1240 // ForeachOp 1241 //===----------------------------------------------------------------------===// 1242 1243 DiagnosedSilenceableFailure 1244 transform::ForeachOp::apply(transform::TransformRewriter &rewriter, 1245 transform::TransformResults &results, 1246 transform::TransformState &state) { 1247 SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {}); 1248 // Store payload ops in a vector because ops may be removed from the mapping 1249 // by the TrackingRewriter while the iteration is in progress. 1250 SmallVector<Operation *> targets = 1251 llvm::to_vector(state.getPayloadOps(getTarget())); 1252 for (Operation *op : targets) { 1253 auto scope = state.make_region_scope(getBody()); 1254 if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) 1255 return DiagnosedSilenceableFailure::definiteFailure(); 1256 1257 // Execute loop body. 1258 for (Operation &transform : getBody().front().without_terminator()) { 1259 DiagnosedSilenceableFailure result = state.applyTransform( 1260 cast<transform::TransformOpInterface>(transform)); 1261 if (!result.succeeded()) 1262 return result; 1263 } 1264 1265 // Append yielded payload ops to result list (if any). 1266 for (unsigned i = 0; i < getNumResults(); ++i) { 1267 auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); 1268 resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); 1269 } 1270 } 1271 1272 for (unsigned i = 0; i < getNumResults(); ++i) 1273 results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]); 1274 1275 return DiagnosedSilenceableFailure::success(); 1276 } 1277 1278 void transform::ForeachOp::getEffects( 1279 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1280 BlockArgument iterVar = getIterationVariable(); 1281 if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1282 return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op)); 1283 })) { 1284 consumesHandle(getTarget(), effects); 1285 } else { 1286 onlyReadsHandle(getTarget(), effects); 1287 } 1288 1289 if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1290 return doesModifyPayload(cast<TransformOpInterface>(&op)); 1291 })) { 1292 modifiesPayload(effects); 1293 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1294 return doesReadPayload(cast<TransformOpInterface>(&op)); 1295 })) { 1296 onlyReadsPayload(effects); 1297 } 1298 1299 for (Value result : getResults()) 1300 producesHandle(result, effects); 1301 } 1302 1303 void transform::ForeachOp::getSuccessorRegions( 1304 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1305 Region *bodyRegion = &getBody(); 1306 if (point.isParent()) { 1307 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1308 return; 1309 } 1310 1311 // Branch back to the region or the parent. 1312 assert(point == getBody() && "unexpected region index"); 1313 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1314 regions.emplace_back(); 1315 } 1316 1317 OperandRange 1318 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { 1319 // The iteration variable op handle is mapped to a subset (one op to be 1320 // precise) of the payload ops of the ForeachOp operand. 1321 assert(point == getBody() && "unexpected region index"); 1322 return getOperation()->getOperands(); 1323 } 1324 1325 transform::YieldOp transform::ForeachOp::getYieldOp() { 1326 return cast<transform::YieldOp>(getBody().front().getTerminator()); 1327 } 1328 1329 LogicalResult transform::ForeachOp::verify() { 1330 auto yieldOp = getYieldOp(); 1331 if (getNumResults() != yieldOp.getNumOperands()) 1332 return emitOpError() << "expects the same number of results as the " 1333 "terminator has operands"; 1334 for (Value v : yieldOp.getOperands()) 1335 if (!llvm::isa<TransformHandleTypeInterface>(v.getType())) 1336 return yieldOp->emitOpError("expects operands to have types implementing " 1337 "TransformHandleTypeInterface"); 1338 return success(); 1339 } 1340 1341 //===----------------------------------------------------------------------===// 1342 // GetParentOp 1343 //===----------------------------------------------------------------------===// 1344 1345 DiagnosedSilenceableFailure 1346 transform::GetParentOp::apply(transform::TransformRewriter &rewriter, 1347 transform::TransformResults &results, 1348 transform::TransformState &state) { 1349 SmallVector<Operation *> parents; 1350 DenseSet<Operation *> resultSet; 1351 for (Operation *target : state.getPayloadOps(getTarget())) { 1352 Operation *parent = target; 1353 for (int64_t i = 0, e = getNthParent(); i < e; ++i) { 1354 parent = parent->getParentOp(); 1355 while (parent) { 1356 bool checkIsolatedFromAbove = 1357 !getIsolatedFromAbove() || 1358 parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); 1359 bool checkOpName = !getOpName().has_value() || 1360 parent->getName().getStringRef() == *getOpName(); 1361 if (checkIsolatedFromAbove && checkOpName) 1362 break; 1363 parent = parent->getParentOp(); 1364 } 1365 if (!parent) { 1366 if (getAllowEmptyResults()) { 1367 results.set(llvm::cast<OpResult>(getResult()), parents); 1368 return DiagnosedSilenceableFailure::success(); 1369 } 1370 DiagnosedSilenceableFailure diag = 1371 emitSilenceableError() 1372 << "could not find a parent op that matches all requirements"; 1373 diag.attachNote(target->getLoc()) << "target op"; 1374 return diag; 1375 } 1376 } 1377 if (getDeduplicate()) { 1378 if (!resultSet.contains(parent)) { 1379 parents.push_back(parent); 1380 resultSet.insert(parent); 1381 } 1382 } else { 1383 parents.push_back(parent); 1384 } 1385 } 1386 results.set(llvm::cast<OpResult>(getResult()), parents); 1387 return DiagnosedSilenceableFailure::success(); 1388 } 1389 1390 //===----------------------------------------------------------------------===// 1391 // GetConsumersOfResult 1392 //===----------------------------------------------------------------------===// 1393 1394 DiagnosedSilenceableFailure 1395 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, 1396 transform::TransformResults &results, 1397 transform::TransformState &state) { 1398 int64_t resultNumber = getResultNumber(); 1399 auto payloadOps = state.getPayloadOps(getTarget()); 1400 if (std::empty(payloadOps)) { 1401 results.set(cast<OpResult>(getResult()), {}); 1402 return DiagnosedSilenceableFailure::success(); 1403 } 1404 if (!llvm::hasSingleElement(payloadOps)) 1405 return emitDefiniteFailure() 1406 << "handle must be mapped to exactly one payload op"; 1407 1408 Operation *target = *payloadOps.begin(); 1409 if (target->getNumResults() <= resultNumber) 1410 return emitDefiniteFailure() << "result number overflow"; 1411 results.set(llvm::cast<OpResult>(getResult()), 1412 llvm::to_vector(target->getResult(resultNumber).getUsers())); 1413 return DiagnosedSilenceableFailure::success(); 1414 } 1415 1416 //===----------------------------------------------------------------------===// 1417 // GetDefiningOp 1418 //===----------------------------------------------------------------------===// 1419 1420 DiagnosedSilenceableFailure 1421 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, 1422 transform::TransformResults &results, 1423 transform::TransformState &state) { 1424 SmallVector<Operation *> definingOps; 1425 for (Value v : state.getPayloadValues(getTarget())) { 1426 if (llvm::isa<BlockArgument>(v)) { 1427 DiagnosedSilenceableFailure diag = 1428 emitSilenceableError() << "cannot get defining op of block argument"; 1429 diag.attachNote(v.getLoc()) << "target value"; 1430 return diag; 1431 } 1432 definingOps.push_back(v.getDefiningOp()); 1433 } 1434 results.set(llvm::cast<OpResult>(getResult()), definingOps); 1435 return DiagnosedSilenceableFailure::success(); 1436 } 1437 1438 //===----------------------------------------------------------------------===// 1439 // GetProducerOfOperand 1440 //===----------------------------------------------------------------------===// 1441 1442 DiagnosedSilenceableFailure 1443 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, 1444 transform::TransformResults &results, 1445 transform::TransformState &state) { 1446 int64_t operandNumber = getOperandNumber(); 1447 SmallVector<Operation *> producers; 1448 for (Operation *target : state.getPayloadOps(getTarget())) { 1449 Operation *producer = 1450 target->getNumOperands() <= operandNumber 1451 ? nullptr 1452 : target->getOperand(operandNumber).getDefiningOp(); 1453 if (!producer) { 1454 DiagnosedSilenceableFailure diag = 1455 emitSilenceableError() 1456 << "could not find a producer for operand number: " << operandNumber 1457 << " of " << *target; 1458 diag.attachNote(target->getLoc()) << "target op"; 1459 return diag; 1460 } 1461 producers.push_back(producer); 1462 } 1463 results.set(llvm::cast<OpResult>(getResult()), producers); 1464 return DiagnosedSilenceableFailure::success(); 1465 } 1466 1467 //===----------------------------------------------------------------------===// 1468 // GetOperandOp 1469 //===----------------------------------------------------------------------===// 1470 1471 DiagnosedSilenceableFailure 1472 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, 1473 transform::TransformResults &results, 1474 transform::TransformState &state) { 1475 SmallVector<Value> operands; 1476 for (Operation *target : state.getPayloadOps(getTarget())) { 1477 SmallVector<int64_t> operandPositions; 1478 DiagnosedSilenceableFailure diag = expandTargetSpecification( 1479 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 1480 target->getNumOperands(), operandPositions); 1481 if (diag.isSilenceableFailure()) { 1482 diag.attachNote(target->getLoc()) 1483 << "while considering positions of this payload operation"; 1484 return diag; 1485 } 1486 llvm::append_range(operands, 1487 llvm::map_range(operandPositions, [&](int64_t pos) { 1488 return target->getOperand(pos); 1489 })); 1490 } 1491 results.setValues(cast<OpResult>(getResult()), operands); 1492 return DiagnosedSilenceableFailure::success(); 1493 } 1494 1495 LogicalResult transform::GetOperandOp::verify() { 1496 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 1497 getIsInverted(), getIsAll()); 1498 } 1499 1500 //===----------------------------------------------------------------------===// 1501 // GetResultOp 1502 //===----------------------------------------------------------------------===// 1503 1504 DiagnosedSilenceableFailure 1505 transform::GetResultOp::apply(transform::TransformRewriter &rewriter, 1506 transform::TransformResults &results, 1507 transform::TransformState &state) { 1508 SmallVector<Value> opResults; 1509 for (Operation *target : state.getPayloadOps(getTarget())) { 1510 SmallVector<int64_t> resultPositions; 1511 DiagnosedSilenceableFailure diag = expandTargetSpecification( 1512 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 1513 target->getNumResults(), resultPositions); 1514 if (diag.isSilenceableFailure()) { 1515 diag.attachNote(target->getLoc()) 1516 << "while considering positions of this payload operation"; 1517 return diag; 1518 } 1519 llvm::append_range(opResults, 1520 llvm::map_range(resultPositions, [&](int64_t pos) { 1521 return target->getResult(pos); 1522 })); 1523 } 1524 results.setValues(cast<OpResult>(getResult()), opResults); 1525 return DiagnosedSilenceableFailure::success(); 1526 } 1527 1528 LogicalResult transform::GetResultOp::verify() { 1529 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 1530 getIsInverted(), getIsAll()); 1531 } 1532 1533 //===----------------------------------------------------------------------===// 1534 // GetTypeOp 1535 //===----------------------------------------------------------------------===// 1536 1537 void transform::GetTypeOp::getEffects( 1538 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1539 onlyReadsHandle(getValue(), effects); 1540 producesHandle(getResult(), effects); 1541 onlyReadsPayload(effects); 1542 } 1543 1544 DiagnosedSilenceableFailure 1545 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, 1546 transform::TransformResults &results, 1547 transform::TransformState &state) { 1548 SmallVector<Attribute> params; 1549 for (Value value : state.getPayloadValues(getValue())) { 1550 Type type = value.getType(); 1551 if (getElemental()) { 1552 if (auto shaped = dyn_cast<ShapedType>(type)) { 1553 type = shaped.getElementType(); 1554 } 1555 } 1556 params.push_back(TypeAttr::get(type)); 1557 } 1558 results.setParams(getResult().cast<OpResult>(), params); 1559 return DiagnosedSilenceableFailure::success(); 1560 } 1561 1562 //===----------------------------------------------------------------------===// 1563 // IncludeOp 1564 //===----------------------------------------------------------------------===// 1565 1566 /// Applies the transform ops contained in `block`. Maps `results` to the same 1567 /// values as the operands of the block terminator. 1568 static DiagnosedSilenceableFailure 1569 applySequenceBlock(Block &block, transform::FailurePropagationMode mode, 1570 transform::TransformState &state, 1571 transform::TransformResults &results) { 1572 // Apply the sequenced ops one by one. 1573 for (Operation &transform : block.without_terminator()) { 1574 DiagnosedSilenceableFailure result = 1575 state.applyTransform(cast<transform::TransformOpInterface>(transform)); 1576 if (result.isDefiniteFailure()) 1577 return result; 1578 1579 if (result.isSilenceableFailure()) { 1580 if (mode == transform::FailurePropagationMode::Propagate) { 1581 // Propagate empty results in case of early exit. 1582 forwardEmptyOperands(&block, state, results); 1583 return result; 1584 } 1585 (void)result.silence(); 1586 } 1587 } 1588 1589 // Forward the operation mapping for values yielded from the sequence to the 1590 // values produced by the sequence op. 1591 transform::detail::forwardTerminatorOperands(&block, state, results); 1592 return DiagnosedSilenceableFailure::success(); 1593 } 1594 1595 DiagnosedSilenceableFailure 1596 transform::IncludeOp::apply(transform::TransformRewriter &rewriter, 1597 transform::TransformResults &results, 1598 transform::TransformState &state) { 1599 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1600 getOperation(), getTarget()); 1601 assert(callee && "unverified reference to unknown symbol"); 1602 1603 if (callee.isExternal()) 1604 return emitDefiniteFailure() << "unresolved external named sequence"; 1605 1606 // Map operands to block arguments. 1607 SmallVector<SmallVector<MappedValue>> mappings; 1608 detail::prepareValueMappings(mappings, getOperands(), state); 1609 auto scope = state.make_region_scope(callee.getBody()); 1610 for (auto &&[arg, map] : 1611 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { 1612 if (failed(state.mapBlockArgument(arg, map))) 1613 return DiagnosedSilenceableFailure::definiteFailure(); 1614 } 1615 1616 DiagnosedSilenceableFailure result = applySequenceBlock( 1617 callee.getBody().front(), getFailurePropagationMode(), state, results); 1618 mappings.clear(); 1619 detail::prepareValueMappings( 1620 mappings, callee.getBody().front().getTerminator()->getOperands(), state); 1621 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) 1622 results.setMappedValues(result, mapping); 1623 return result; 1624 } 1625 1626 static DiagnosedSilenceableFailure 1627 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings); 1628 1629 void transform::IncludeOp::getEffects( 1630 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1631 // Always mark as modifying the payload. 1632 // TODO: a mechanism to annotate effects on payload. Even when all handles are 1633 // only read, the payload may still be modified, so we currently stay on the 1634 // conservative side and always indicate modification. This may prevent some 1635 // code reordering. 1636 modifiesPayload(effects); 1637 1638 // Results are always produced. 1639 producesHandle(getResults(), effects); 1640 1641 // Adds default effects to operands and results. This will be added if 1642 // preconditions fail so the trait verifier doesn't complain about missing 1643 // effects and the real precondition failure is reported later on. 1644 auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); }; 1645 1646 // Bail if the callee is unknown. This may run as part of the verification 1647 // process before we verified the validity of the callee or of this op. 1648 auto target = 1649 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName()); 1650 if (!target) 1651 return defaultEffects(); 1652 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1653 getOperation(), getTarget()); 1654 if (!callee) 1655 return defaultEffects(); 1656 DiagnosedSilenceableFailure earlyVerifierResult = 1657 verifyNamedSequenceOp(callee, /*emitWarnings=*/false); 1658 if (!earlyVerifierResult.succeeded()) { 1659 (void)earlyVerifierResult.silence(); 1660 return defaultEffects(); 1661 } 1662 1663 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { 1664 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) 1665 consumesHandle(getOperand(i), effects); 1666 else 1667 onlyReadsHandle(getOperand(i), effects); 1668 } 1669 } 1670 1671 LogicalResult 1672 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1673 // Access through indirection and do additional checking because this may be 1674 // running before the main op verifier. 1675 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target"); 1676 if (!targetAttr) 1677 return emitOpError() << "expects a 'target' symbol reference attribute"; 1678 1679 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>( 1680 *this, targetAttr); 1681 if (!target) 1682 return emitOpError() << "does not reference a named transform sequence"; 1683 1684 FunctionType fnType = target.getFunctionType(); 1685 if (fnType.getNumInputs() != getNumOperands()) 1686 return emitError("incorrect number of operands for callee"); 1687 1688 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { 1689 if (getOperand(i).getType() != fnType.getInput(i)) { 1690 return emitOpError("operand type mismatch: expected operand type ") 1691 << fnType.getInput(i) << ", but provided " 1692 << getOperand(i).getType() << " for operand number " << i; 1693 } 1694 } 1695 1696 if (fnType.getNumResults() != getNumResults()) 1697 return emitError("incorrect number of results for callee"); 1698 1699 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { 1700 Type resultType = getResult(i).getType(); 1701 Type funcType = fnType.getResult(i); 1702 if (!implementSameTransformInterface(resultType, funcType)) { 1703 return emitOpError() << "type of result #" << i 1704 << " must implement the same transform dialect " 1705 "interface as the corresponding callee result"; 1706 } 1707 } 1708 1709 return verifyFunctionLikeConsumeAnnotations( 1710 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false, 1711 /*alsoVerifyInternal=*/true) 1712 .checkAndReport(); 1713 } 1714 1715 //===----------------------------------------------------------------------===// 1716 // MatchOperationEmptyOp 1717 //===----------------------------------------------------------------------===// 1718 1719 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( 1720 ::std::optional<::mlir::Operation *> maybeCurrent, 1721 transform::TransformResults &results, transform::TransformState &state) { 1722 if (!maybeCurrent.has_value()) { 1723 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); 1724 return DiagnosedSilenceableFailure::success(); 1725 } 1726 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); 1727 return emitSilenceableError() << "operation is not empty"; 1728 } 1729 1730 //===----------------------------------------------------------------------===// 1731 // MatchOperationNameOp 1732 //===----------------------------------------------------------------------===// 1733 1734 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( 1735 Operation *current, transform::TransformResults &results, 1736 transform::TransformState &state) { 1737 StringRef currentOpName = current->getName().getStringRef(); 1738 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) { 1739 if (acceptedAttr.getValue() == currentOpName) 1740 return DiagnosedSilenceableFailure::success(); 1741 } 1742 return emitSilenceableError() << "wrong operation name"; 1743 } 1744 1745 //===----------------------------------------------------------------------===// 1746 // MatchParamCmpIOp 1747 //===----------------------------------------------------------------------===// 1748 1749 DiagnosedSilenceableFailure 1750 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, 1751 transform::TransformResults &results, 1752 transform::TransformState &state) { 1753 auto signedAPIntAsString = [&](APInt value) { 1754 std::string str; 1755 llvm::raw_string_ostream os(str); 1756 value.print(os, /*isSigned=*/true); 1757 return os.str(); 1758 }; 1759 1760 ArrayRef<Attribute> params = state.getParams(getParam()); 1761 ArrayRef<Attribute> references = state.getParams(getReference()); 1762 1763 if (params.size() != references.size()) { 1764 return emitSilenceableError() 1765 << "parameters have different payload lengths (" << params.size() 1766 << " vs " << references.size() << ")"; 1767 } 1768 1769 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { 1770 auto intAttr = llvm::dyn_cast<IntegerAttr>(param); 1771 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference); 1772 if (!intAttr || !refAttr) { 1773 return emitDefiniteFailure() 1774 << "non-integer parameter value not expected"; 1775 } 1776 if (intAttr.getType() != refAttr.getType()) { 1777 return emitDefiniteFailure() 1778 << "mismatching integer attribute types in parameter #" << i; 1779 } 1780 APInt value = intAttr.getValue(); 1781 APInt refValue = refAttr.getValue(); 1782 1783 // TODO: this copy will not be necessary in C++20. 1784 int64_t position = i; 1785 auto reportError = [&](StringRef direction) { 1786 DiagnosedSilenceableFailure diag = 1787 emitSilenceableError() << "expected parameter to be " << direction 1788 << " " << signedAPIntAsString(refValue) 1789 << ", got " << signedAPIntAsString(value); 1790 diag.attachNote(getParam().getLoc()) 1791 << "value # " << position 1792 << " associated with the parameter defined here"; 1793 return diag; 1794 }; 1795 1796 switch (getPredicate()) { 1797 case MatchCmpIPredicate::eq: 1798 if (value.eq(refValue)) 1799 break; 1800 return reportError("equal to"); 1801 case MatchCmpIPredicate::ne: 1802 if (value.ne(refValue)) 1803 break; 1804 return reportError("not equal to"); 1805 case MatchCmpIPredicate::lt: 1806 if (value.slt(refValue)) 1807 break; 1808 return reportError("less than"); 1809 case MatchCmpIPredicate::le: 1810 if (value.sle(refValue)) 1811 break; 1812 return reportError("less than or equal to"); 1813 case MatchCmpIPredicate::gt: 1814 if (value.sgt(refValue)) 1815 break; 1816 return reportError("greater than"); 1817 case MatchCmpIPredicate::ge: 1818 if (value.sge(refValue)) 1819 break; 1820 return reportError("greater than or equal to"); 1821 } 1822 } 1823 return DiagnosedSilenceableFailure::success(); 1824 } 1825 1826 void transform::MatchParamCmpIOp::getEffects( 1827 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1828 onlyReadsHandle(getParam(), effects); 1829 onlyReadsHandle(getReference(), effects); 1830 } 1831 1832 //===----------------------------------------------------------------------===// 1833 // ParamConstantOp 1834 //===----------------------------------------------------------------------===// 1835 1836 DiagnosedSilenceableFailure 1837 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, 1838 transform::TransformResults &results, 1839 transform::TransformState &state) { 1840 results.setParams(cast<OpResult>(getParam()), {getValue()}); 1841 return DiagnosedSilenceableFailure::success(); 1842 } 1843 1844 //===----------------------------------------------------------------------===// 1845 // MergeHandlesOp 1846 //===----------------------------------------------------------------------===// 1847 1848 DiagnosedSilenceableFailure 1849 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, 1850 transform::TransformResults &results, 1851 transform::TransformState &state) { 1852 ValueRange handles = getHandles(); 1853 if (isa<TransformHandleTypeInterface>(handles.front().getType())) { 1854 SmallVector<Operation *> operations; 1855 for (Value operand : handles) 1856 llvm::append_range(operations, state.getPayloadOps(operand)); 1857 if (!getDeduplicate()) { 1858 results.set(llvm::cast<OpResult>(getResult()), operations); 1859 return DiagnosedSilenceableFailure::success(); 1860 } 1861 1862 SetVector<Operation *> uniqued(operations.begin(), operations.end()); 1863 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); 1864 return DiagnosedSilenceableFailure::success(); 1865 } 1866 1867 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) { 1868 SmallVector<Attribute> attrs; 1869 for (Value attribute : handles) 1870 llvm::append_range(attrs, state.getParams(attribute)); 1871 if (!getDeduplicate()) { 1872 results.setParams(cast<OpResult>(getResult()), attrs); 1873 return DiagnosedSilenceableFailure::success(); 1874 } 1875 1876 SetVector<Attribute> uniqued(attrs.begin(), attrs.end()); 1877 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); 1878 return DiagnosedSilenceableFailure::success(); 1879 } 1880 1881 assert( 1882 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) && 1883 "expected value handle type"); 1884 SmallVector<Value> payloadValues; 1885 for (Value value : handles) 1886 llvm::append_range(payloadValues, state.getPayloadValues(value)); 1887 if (!getDeduplicate()) { 1888 results.setValues(cast<OpResult>(getResult()), payloadValues); 1889 return DiagnosedSilenceableFailure::success(); 1890 } 1891 1892 SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end()); 1893 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); 1894 return DiagnosedSilenceableFailure::success(); 1895 } 1896 1897 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { 1898 // Handles may be the same if deduplicating is enabled. 1899 return getDeduplicate(); 1900 } 1901 1902 void transform::MergeHandlesOp::getEffects( 1903 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1904 onlyReadsHandle(getHandles(), effects); 1905 producesHandle(getResult(), effects); 1906 1907 // There are no effects on the Payload IR as this is only a handle 1908 // manipulation. 1909 } 1910 1911 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { 1912 if (getDeduplicate() || getHandles().size() != 1) 1913 return {}; 1914 1915 // If deduplication is not required and there is only one operand, it can be 1916 // used directly instead of merging. 1917 return getHandles().front(); 1918 } 1919 1920 //===----------------------------------------------------------------------===// 1921 // NamedSequenceOp 1922 //===----------------------------------------------------------------------===// 1923 1924 DiagnosedSilenceableFailure 1925 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, 1926 transform::TransformResults &results, 1927 transform::TransformState &state) { 1928 if (isExternal()) 1929 return emitDefiniteFailure() << "unresolved external named sequence"; 1930 1931 // Map the entry block argument to the list of operations. 1932 // Note: this is the same implementation as PossibleTopLevelTransformOp but 1933 // without attaching the interface / trait since that is tailored to a 1934 // dangling top-level op that does not get "called". 1935 auto scope = state.make_region_scope(getBody()); 1936 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( 1937 state, this->getOperation(), getBody()))) 1938 return DiagnosedSilenceableFailure::definiteFailure(); 1939 1940 return applySequenceBlock(getBody().front(), 1941 FailurePropagationMode::Propagate, state, results); 1942 } 1943 1944 void transform::NamedSequenceOp::getEffects( 1945 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 1946 1947 ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, 1948 OperationState &result) { 1949 return function_interface_impl::parseFunctionOp( 1950 parser, result, /*allowVariadic=*/false, 1951 getFunctionTypeAttrName(result.name), 1952 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results, 1953 function_interface_impl::VariadicFlag, 1954 std::string &) { return builder.getFunctionType(inputs, results); }, 1955 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 1956 } 1957 1958 void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { 1959 function_interface_impl::printFunctionOp( 1960 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false, 1961 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), 1962 getResAttrsAttrName()); 1963 } 1964 1965 /// Verifies that a symbol function-like transform dialect operation has the 1966 /// signature and the terminator that have conforming types, i.e., types 1967 /// implementing the same transform dialect type interface. If `allowExternal` 1968 /// is set, allow external symbols (declarations) and don't check the terminator 1969 /// as it may not exist. 1970 static DiagnosedSilenceableFailure 1971 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { 1972 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 1973 DiagnosedSilenceableFailure diag = 1974 emitSilenceableFailure(op) 1975 << "cannot be defined inside another transform op"; 1976 diag.attachNote(parent.getLoc()) << "ancestor transform op"; 1977 return diag; 1978 } 1979 1980 if (op.isExternal() || op.getFunctionBody().empty()) { 1981 if (allowExternal) 1982 return DiagnosedSilenceableFailure::success(); 1983 1984 return emitSilenceableFailure(op) << "cannot be external"; 1985 } 1986 1987 if (op.getFunctionBody().front().empty()) 1988 return emitSilenceableFailure(op) << "expected a non-empty body block"; 1989 1990 Operation *terminator = &op.getFunctionBody().front().back(); 1991 if (!isa<transform::YieldOp>(terminator)) { 1992 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 1993 << "expected '" 1994 << transform::YieldOp::getOperationName() 1995 << "' as terminator"; 1996 diag.attachNote(terminator->getLoc()) << "terminator"; 1997 return diag; 1998 } 1999 2000 if (terminator->getNumOperands() != op.getResultTypes().size()) { 2001 return emitSilenceableFailure(terminator) 2002 << "expected terminator to have as many operands as the parent op " 2003 "has results"; 2004 } 2005 for (auto [i, operandType, resultType] : llvm::zip_equal( 2006 llvm::seq<unsigned>(0, terminator->getNumOperands()), 2007 terminator->getOperands().getType(), op.getResultTypes())) { 2008 if (operandType == resultType) 2009 continue; 2010 return emitSilenceableFailure(terminator) 2011 << "the type of the terminator operand #" << i 2012 << " must match the type of the corresponding parent op result (" 2013 << operandType << " vs " << resultType << ")"; 2014 } 2015 2016 return DiagnosedSilenceableFailure::success(); 2017 } 2018 2019 /// Verification of a NamedSequenceOp. This does not report the error 2020 /// immediately, so it can be used to check for op's well-formedness before the 2021 /// verifier runs, e.g., during trait verification. 2022 static DiagnosedSilenceableFailure 2023 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { 2024 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) { 2025 if (!parent->getAttr( 2026 transform::TransformDialect::kWithNamedSequenceAttrName)) { 2027 DiagnosedSilenceableFailure diag = 2028 emitSilenceableFailure(op) 2029 << "expects the parent symbol table to have the '" 2030 << transform::TransformDialect::kWithNamedSequenceAttrName 2031 << "' attribute"; 2032 diag.attachNote(parent->getLoc()) << "symbol table operation"; 2033 return diag; 2034 } 2035 } 2036 2037 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 2038 DiagnosedSilenceableFailure diag = 2039 emitSilenceableFailure(op) 2040 << "cannot be defined inside another transform op"; 2041 diag.attachNote(parent.getLoc()) << "ancestor transform op"; 2042 return diag; 2043 } 2044 2045 if (op.isExternal() || op.getBody().empty()) 2046 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op), 2047 emitWarnings); 2048 2049 if (op.getBody().front().empty()) 2050 return emitSilenceableFailure(op) << "expected a non-empty body block"; 2051 2052 Operation *terminator = &op.getBody().front().back(); 2053 if (!isa<transform::YieldOp>(terminator)) { 2054 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 2055 << "expected '" 2056 << transform::YieldOp::getOperationName() 2057 << "' as terminator"; 2058 diag.attachNote(terminator->getLoc()) << "terminator"; 2059 return diag; 2060 } 2061 2062 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { 2063 return emitSilenceableFailure(terminator) 2064 << "expected terminator to have as many operands as the parent op " 2065 "has results"; 2066 } 2067 for (auto [i, operandType, resultType] : 2068 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()), 2069 terminator->getOperands().getType(), 2070 op.getFunctionType().getResults())) { 2071 if (operandType == resultType) 2072 continue; 2073 return emitSilenceableFailure(terminator) 2074 << "the type of the terminator operand #" << i 2075 << " must match the type of the corresponding parent op result (" 2076 << operandType << " vs " << resultType << ")"; 2077 } 2078 2079 auto funcOp = cast<FunctionOpInterface>(*op); 2080 DiagnosedSilenceableFailure diag = 2081 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); 2082 if (!diag.succeeded()) 2083 return diag; 2084 2085 return verifyYieldingSingleBlockOp(funcOp, 2086 /*allowExternal=*/true); 2087 } 2088 2089 LogicalResult transform::NamedSequenceOp::verify() { 2090 // Actual verification happens in a separate function for reusability. 2091 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); 2092 } 2093 2094 template <typename FnTy> 2095 static void buildSequenceBody(OpBuilder &builder, OperationState &state, 2096 Type bbArgType, TypeRange extraBindingTypes, 2097 FnTy bodyBuilder) { 2098 SmallVector<Type> types; 2099 types.reserve(1 + extraBindingTypes.size()); 2100 types.push_back(bbArgType); 2101 llvm::append_range(types, extraBindingTypes); 2102 2103 OpBuilder::InsertionGuard guard(builder); 2104 Region *region = state.regions.back().get(); 2105 Block *bodyBlock = 2106 builder.createBlock(region, region->begin(), types, 2107 SmallVector<Location>(types.size(), state.location)); 2108 2109 // Populate body. 2110 builder.setInsertionPointToStart(bodyBlock); 2111 if constexpr (llvm::function_traits<FnTy>::num_args == 3) { 2112 bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); 2113 } else { 2114 bodyBuilder(builder, state.location, bodyBlock->getArgument(0), 2115 bodyBlock->getArguments().drop_front()); 2116 } 2117 } 2118 2119 void transform::NamedSequenceOp::build(OpBuilder &builder, 2120 OperationState &state, StringRef symName, 2121 Type rootType, TypeRange resultTypes, 2122 SequenceBodyBuilderFn bodyBuilder, 2123 ArrayRef<NamedAttribute> attrs, 2124 ArrayRef<DictionaryAttr> argAttrs) { 2125 state.addAttribute(SymbolTable::getSymbolAttrName(), 2126 builder.getStringAttr(symName)); 2127 state.addAttribute(getFunctionTypeAttrName(state.name), 2128 TypeAttr::get(FunctionType::get(builder.getContext(), 2129 rootType, resultTypes))); 2130 state.attributes.append(attrs.begin(), attrs.end()); 2131 state.addRegion(); 2132 2133 buildSequenceBody(builder, state, rootType, 2134 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2135 } 2136 2137 //===----------------------------------------------------------------------===// 2138 // NumAssociationsOp 2139 //===----------------------------------------------------------------------===// 2140 2141 DiagnosedSilenceableFailure 2142 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, 2143 transform::TransformResults &results, 2144 transform::TransformState &state) { 2145 size_t numAssociations = 2146 llvm::TypeSwitch<Type, size_t>(getHandle().getType()) 2147 .Case([&](TransformHandleTypeInterface opHandle) { 2148 return llvm::range_size(state.getPayloadOps(getHandle())); 2149 }) 2150 .Case([&](TransformValueHandleTypeInterface valueHandle) { 2151 return llvm::range_size(state.getPayloadValues(getHandle())); 2152 }) 2153 .Case([&](TransformParamTypeInterface param) { 2154 return llvm::range_size(state.getParams(getHandle())); 2155 }) 2156 .Default([](Type) { 2157 llvm_unreachable("unknown kind of transform dialect type"); 2158 return 0; 2159 }); 2160 results.setParams(getNum().cast<OpResult>(), 2161 rewriter.getI64IntegerAttr(numAssociations)); 2162 return DiagnosedSilenceableFailure::success(); 2163 } 2164 2165 LogicalResult transform::NumAssociationsOp::verify() { 2166 // Verify that the result type accepts an i64 attribute as payload. 2167 auto resultType = getNum().getType().cast<TransformParamTypeInterface>(); 2168 return resultType 2169 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) 2170 .checkAndReport(); 2171 } 2172 2173 //===----------------------------------------------------------------------===// 2174 // SelectOp 2175 //===----------------------------------------------------------------------===// 2176 2177 DiagnosedSilenceableFailure 2178 transform::SelectOp::apply(transform::TransformRewriter &rewriter, 2179 transform::TransformResults &results, 2180 transform::TransformState &state) { 2181 SmallVector<Operation *> result; 2182 auto payloadOps = state.getPayloadOps(getTarget()); 2183 for (Operation *op : payloadOps) { 2184 if (op->getName().getStringRef() == getOpName()) 2185 result.push_back(op); 2186 } 2187 results.set(cast<OpResult>(getResult()), result); 2188 return DiagnosedSilenceableFailure::success(); 2189 } 2190 2191 //===----------------------------------------------------------------------===// 2192 // SplitHandleOp 2193 //===----------------------------------------------------------------------===// 2194 2195 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, 2196 Value target, int64_t numResultHandles) { 2197 result.addOperands(target); 2198 result.addTypes(SmallVector<Type>(numResultHandles, target.getType())); 2199 } 2200 2201 DiagnosedSilenceableFailure 2202 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, 2203 transform::TransformResults &results, 2204 transform::TransformState &state) { 2205 int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); 2206 auto produceNumOpsError = [&]() { 2207 return emitSilenceableError() 2208 << getHandle() << " expected to contain " << this->getNumResults() 2209 << " payload ops but it contains " << numPayloadOps 2210 << " payload ops"; 2211 }; 2212 2213 // Fail if there are more payload ops than results and no overflow result was 2214 // specified. 2215 if (numPayloadOps > getNumResults() && !getOverflowResult().has_value()) 2216 return produceNumOpsError(); 2217 2218 // Fail if there are more results than payload ops. Unless: 2219 // - "fail_on_payload_too_small" is set to "false", or 2220 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. 2221 if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() && 2222 !(numPayloadOps == 0 && getPassThroughEmptyHandle())) 2223 return produceNumOpsError(); 2224 2225 // Distribute payload ops. 2226 SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {}); 2227 if (getOverflowResult()) 2228 resultHandles[*getOverflowResult()].reserve(numPayloadOps - 2229 getNumResults()); 2230 for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) { 2231 int64_t resultNum = en.index(); 2232 if (resultNum >= getNumResults()) 2233 resultNum = *getOverflowResult(); 2234 resultHandles[resultNum].push_back(en.value()); 2235 } 2236 2237 // Set transform op results. 2238 for (auto &&it : llvm::enumerate(resultHandles)) 2239 results.set(llvm::cast<OpResult>(getResult(it.index())), it.value()); 2240 2241 return DiagnosedSilenceableFailure::success(); 2242 } 2243 2244 void transform::SplitHandleOp::getEffects( 2245 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2246 onlyReadsHandle(getHandle(), effects); 2247 producesHandle(getResults(), effects); 2248 // There are no effects on the Payload IR as this is only a handle 2249 // manipulation. 2250 } 2251 2252 LogicalResult transform::SplitHandleOp::verify() { 2253 if (getOverflowResult().has_value() && 2254 !(*getOverflowResult() < getNumResults())) 2255 return emitOpError("overflow_result is not a valid result index"); 2256 return success(); 2257 } 2258 2259 //===----------------------------------------------------------------------===// 2260 // ReplicateOp 2261 //===----------------------------------------------------------------------===// 2262 2263 DiagnosedSilenceableFailure 2264 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, 2265 transform::TransformResults &results, 2266 transform::TransformState &state) { 2267 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); 2268 for (const auto &en : llvm::enumerate(getHandles())) { 2269 Value handle = en.value(); 2270 if (isa<TransformHandleTypeInterface>(handle.getType())) { 2271 SmallVector<Operation *> current = 2272 llvm::to_vector(state.getPayloadOps(handle)); 2273 SmallVector<Operation *> payload; 2274 payload.reserve(numRepetitions * current.size()); 2275 for (unsigned i = 0; i < numRepetitions; ++i) 2276 llvm::append_range(payload, current); 2277 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload); 2278 } else { 2279 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) && 2280 "expected param type"); 2281 ArrayRef<Attribute> current = state.getParams(handle); 2282 SmallVector<Attribute> params; 2283 params.reserve(numRepetitions * current.size()); 2284 for (unsigned i = 0; i < numRepetitions; ++i) 2285 llvm::append_range(params, current); 2286 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]), 2287 params); 2288 } 2289 } 2290 return DiagnosedSilenceableFailure::success(); 2291 } 2292 2293 void transform::ReplicateOp::getEffects( 2294 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2295 onlyReadsHandle(getPattern(), effects); 2296 onlyReadsHandle(getHandles(), effects); 2297 producesHandle(getReplicated(), effects); 2298 } 2299 2300 //===----------------------------------------------------------------------===// 2301 // SequenceOp 2302 //===----------------------------------------------------------------------===// 2303 2304 DiagnosedSilenceableFailure 2305 transform::SequenceOp::apply(transform::TransformRewriter &rewriter, 2306 transform::TransformResults &results, 2307 transform::TransformState &state) { 2308 // Map the entry block argument to the list of operations. 2309 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 2310 if (failed(mapBlockArguments(state))) 2311 return DiagnosedSilenceableFailure::definiteFailure(); 2312 2313 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, 2314 results); 2315 } 2316 2317 static ParseResult parseSequenceOpOperands( 2318 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 2319 Type &rootType, 2320 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 2321 SmallVectorImpl<Type> &extraBindingTypes) { 2322 OpAsmParser::UnresolvedOperand rootOperand; 2323 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand); 2324 if (!hasRoot.has_value()) { 2325 root = std::nullopt; 2326 return success(); 2327 } 2328 if (failed(hasRoot.value())) 2329 return failure(); 2330 root = rootOperand; 2331 2332 if (succeeded(parser.parseOptionalComma())) { 2333 if (failed(parser.parseOperandList(extraBindings))) 2334 return failure(); 2335 } 2336 if (failed(parser.parseColon())) 2337 return failure(); 2338 2339 // The paren is truly optional. 2340 (void)parser.parseOptionalLParen(); 2341 2342 if (failed(parser.parseType(rootType))) { 2343 return failure(); 2344 } 2345 2346 if (!extraBindings.empty()) { 2347 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes)) 2348 return failure(); 2349 } 2350 2351 if (extraBindingTypes.size() != extraBindings.size()) { 2352 return parser.emitError(parser.getNameLoc(), 2353 "expected types to be provided for all operands"); 2354 } 2355 2356 // The paren is truly optional. 2357 (void)parser.parseOptionalRParen(); 2358 return success(); 2359 } 2360 2361 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 2362 Value root, Type rootType, 2363 ValueRange extraBindings, 2364 TypeRange extraBindingTypes) { 2365 if (!root) 2366 return; 2367 2368 printer << root; 2369 bool hasExtras = !extraBindings.empty(); 2370 if (hasExtras) { 2371 printer << ", "; 2372 printer.printOperands(extraBindings); 2373 } 2374 2375 printer << " : "; 2376 if (hasExtras) 2377 printer << "("; 2378 2379 printer << rootType; 2380 if (hasExtras) { 2381 printer << ", "; 2382 llvm::interleaveComma(extraBindingTypes, printer.getStream()); 2383 printer << ")"; 2384 } 2385 } 2386 2387 /// Returns `true` if the given op operand may be consuming the handle value in 2388 /// the Transform IR. That is, if it may have a Free effect on it. 2389 static bool isValueUsePotentialConsumer(OpOperand &use) { 2390 // Conservatively assume the effect being present in absence of the interface. 2391 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); 2392 if (!iface) 2393 return true; 2394 2395 return isHandleConsumed(use.get(), iface); 2396 } 2397 2398 LogicalResult 2399 checkDoubleConsume(Value value, 2400 function_ref<InFlightDiagnostic()> reportError) { 2401 OpOperand *potentialConsumer = nullptr; 2402 for (OpOperand &use : value.getUses()) { 2403 if (!isValueUsePotentialConsumer(use)) 2404 continue; 2405 2406 if (!potentialConsumer) { 2407 potentialConsumer = &use; 2408 continue; 2409 } 2410 2411 InFlightDiagnostic diag = reportError() 2412 << " has more than one potential consumer"; 2413 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 2414 << "used here as operand #" << potentialConsumer->getOperandNumber(); 2415 diag.attachNote(use.getOwner()->getLoc()) 2416 << "used here as operand #" << use.getOperandNumber(); 2417 return diag; 2418 } 2419 2420 return success(); 2421 } 2422 2423 LogicalResult transform::SequenceOp::verify() { 2424 assert(getBodyBlock()->getNumArguments() >= 1 && 2425 "the number of arguments must have been verified to be more than 1 by " 2426 "PossibleTopLevelTransformOpTrait"); 2427 2428 if (!getRoot() && !getExtraBindings().empty()) { 2429 return emitOpError() 2430 << "does not expect extra operands when used as top-level"; 2431 } 2432 2433 // Check if a block argument has more than one consuming use. 2434 for (BlockArgument arg : getBodyBlock()->getArguments()) { 2435 if (failed(checkDoubleConsume(arg, [this, arg]() { 2436 return (emitOpError() << "block argument #" << arg.getArgNumber()); 2437 }))) { 2438 return failure(); 2439 } 2440 } 2441 2442 // Check properties of the nested operations they cannot check themselves. 2443 for (Operation &child : *getBodyBlock()) { 2444 if (!isa<TransformOpInterface>(child) && 2445 &child != &getBodyBlock()->back()) { 2446 InFlightDiagnostic diag = 2447 emitOpError() 2448 << "expected children ops to implement TransformOpInterface"; 2449 diag.attachNote(child.getLoc()) << "op without interface"; 2450 return diag; 2451 } 2452 2453 for (OpResult result : child.getResults()) { 2454 auto report = [&]() { 2455 return (child.emitError() << "result #" << result.getResultNumber()); 2456 }; 2457 if (failed(checkDoubleConsume(result, report))) 2458 return failure(); 2459 } 2460 } 2461 2462 if (!getBodyBlock()->mightHaveTerminator()) 2463 return emitOpError() << "expects to have a terminator in the body"; 2464 2465 if (getBodyBlock()->getTerminator()->getOperandTypes() != 2466 getOperation()->getResultTypes()) { 2467 InFlightDiagnostic diag = emitOpError() 2468 << "expects the types of the terminator operands " 2469 "to match the types of the result"; 2470 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 2471 return diag; 2472 } 2473 return success(); 2474 } 2475 2476 void transform::SequenceOp::getEffects( 2477 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2478 getPotentialTopLevelEffects(effects); 2479 } 2480 2481 OperandRange 2482 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2483 assert(point == getBody() && "unexpected region index"); 2484 if (getOperation()->getNumOperands() > 0) 2485 return getOperation()->getOperands(); 2486 return OperandRange(getOperation()->operand_end(), 2487 getOperation()->operand_end()); 2488 } 2489 2490 void transform::SequenceOp::getSuccessorRegions( 2491 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 2492 if (point.isParent()) { 2493 Region *bodyRegion = &getBody(); 2494 regions.emplace_back(bodyRegion, getNumOperands() != 0 2495 ? bodyRegion->getArguments() 2496 : Block::BlockArgListType()); 2497 return; 2498 } 2499 2500 assert(point == getBody() && "unexpected region index"); 2501 regions.emplace_back(getOperation()->getResults()); 2502 } 2503 2504 void transform::SequenceOp::getRegionInvocationBounds( 2505 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 2506 (void)operands; 2507 bounds.emplace_back(1, 1); 2508 } 2509 2510 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2511 TypeRange resultTypes, 2512 FailurePropagationMode failurePropagationMode, 2513 Value root, 2514 SequenceBodyBuilderFn bodyBuilder) { 2515 build(builder, state, resultTypes, failurePropagationMode, root, 2516 /*extra_bindings=*/ValueRange()); 2517 Type bbArgType = root.getType(); 2518 buildSequenceBody(builder, state, bbArgType, 2519 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2520 } 2521 2522 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2523 TypeRange resultTypes, 2524 FailurePropagationMode failurePropagationMode, 2525 Value root, ValueRange extraBindings, 2526 SequenceBodyBuilderArgsFn bodyBuilder) { 2527 build(builder, state, resultTypes, failurePropagationMode, root, 2528 extraBindings); 2529 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), 2530 bodyBuilder); 2531 } 2532 2533 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2534 TypeRange resultTypes, 2535 FailurePropagationMode failurePropagationMode, 2536 Type bbArgType, 2537 SequenceBodyBuilderFn bodyBuilder) { 2538 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 2539 /*extra_bindings=*/ValueRange()); 2540 buildSequenceBody(builder, state, bbArgType, 2541 /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2542 } 2543 2544 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2545 TypeRange resultTypes, 2546 FailurePropagationMode failurePropagationMode, 2547 Type bbArgType, TypeRange extraBindingTypes, 2548 SequenceBodyBuilderArgsFn bodyBuilder) { 2549 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 2550 /*extra_bindings=*/ValueRange()); 2551 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); 2552 } 2553 2554 //===----------------------------------------------------------------------===// 2555 // PrintOp 2556 //===----------------------------------------------------------------------===// 2557 2558 void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2559 StringRef name) { 2560 if (!name.empty()) 2561 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name); 2562 } 2563 2564 void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2565 Value target, StringRef name) { 2566 result.addOperands({target}); 2567 build(builder, result, name); 2568 } 2569 2570 DiagnosedSilenceableFailure 2571 transform::PrintOp::apply(transform::TransformRewriter &rewriter, 2572 transform::TransformResults &results, 2573 transform::TransformState &state) { 2574 llvm::outs() << "[[[ IR printer: "; 2575 if (getName().has_value()) 2576 llvm::outs() << *getName() << " "; 2577 2578 if (!getTarget()) { 2579 llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; 2580 return DiagnosedSilenceableFailure::success(); 2581 } 2582 2583 llvm::outs() << "]]]\n"; 2584 for (Operation *target : state.getPayloadOps(getTarget())) 2585 llvm::outs() << *target << "\n"; 2586 2587 return DiagnosedSilenceableFailure::success(); 2588 } 2589 2590 void transform::PrintOp::getEffects( 2591 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2592 onlyReadsHandle(getTarget(), effects); 2593 onlyReadsPayload(effects); 2594 2595 // There is no resource for stderr file descriptor, so just declare print 2596 // writes into the default resource. 2597 effects.emplace_back(MemoryEffects::Write::get()); 2598 } 2599 2600 //===----------------------------------------------------------------------===// 2601 // VerifyOp 2602 //===----------------------------------------------------------------------===// 2603 2604 DiagnosedSilenceableFailure 2605 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, 2606 Operation *target, 2607 transform::ApplyToEachResultList &results, 2608 transform::TransformState &state) { 2609 if (failed(::mlir::verify(target))) { 2610 DiagnosedDefiniteFailure diag = emitDefiniteFailure() 2611 << "failed to verify payload op"; 2612 diag.attachNote(target->getLoc()) << "payload op"; 2613 return diag; 2614 } 2615 return DiagnosedSilenceableFailure::success(); 2616 } 2617 2618 void transform::VerifyOp::getEffects( 2619 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2620 transform::onlyReadsHandle(getTarget(), effects); 2621 } 2622 2623 //===----------------------------------------------------------------------===// 2624 // YieldOp 2625 //===----------------------------------------------------------------------===// 2626 2627 void transform::YieldOp::getEffects( 2628 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2629 onlyReadsHandle(getOperands(), effects); 2630 } 2631