xref: /llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp (revision f6bfbc87779ef2079e9b1356ac21381659f13fbb)
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
17 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
18 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Verifier.h"
26 #include "mlir/Interfaces/CallInterfaces.h"
27 #include "mlir/Interfaces/ControlFlowInterfaces.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Interfaces/FunctionInterfaces.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Pass/PassRegistry.h"
33 #include "mlir/Transforms/CSE.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
36 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include <optional>
45 
46 #define DEBUG_TYPE "transform-dialect"
47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
48 
49 #define DEBUG_TYPE_MATCHER "transform-matcher"
50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
52 
53 using namespace mlir;
54 
55 static ParseResult parseSequenceOpOperands(
56     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
57     Type &rootType,
58     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59     SmallVectorImpl<Type> &extraBindingTypes);
60 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
61                                     Value root, Type rootType,
62                                     ValueRange extraBindings,
63                                     TypeRange extraBindingTypes);
64 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
65                                      ArrayAttr matchers, ArrayAttr actions);
66 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
67                                             ArrayAttr &matchers,
68                                             ArrayAttr &actions);
69 
70 /// Helper function to check if the given transform op is contained in (or
71 /// equal to) the given payload target op. In that case, an error is returned.
72 /// Transforming transform IR that is currently executing is generally unsafe.
73 static DiagnosedSilenceableFailure
74 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
75                                      Operation *payload) {
76   Operation *transformAncestor = transform.getOperation();
77   while (transformAncestor) {
78     if (transformAncestor == payload) {
79       DiagnosedDefiniteFailure diag =
80           transform.emitDefiniteFailure()
81           << "cannot apply transform to itself (or one of its ancestors)";
82       diag.attachNote(payload->getLoc()) << "target payload op";
83       return diag;
84     }
85     transformAncestor = transformAncestor->getParentOp();
86   }
87   return DiagnosedSilenceableFailure::success();
88 }
89 
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
92 
93 //===----------------------------------------------------------------------===//
94 // AlternativesOp
95 //===----------------------------------------------------------------------===//
96 
97 OperandRange
98 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
99   if (!point.isParent() && getOperation()->getNumOperands() == 1)
100     return getOperation()->getOperands();
101   return OperandRange(getOperation()->operand_end(),
102                       getOperation()->operand_end());
103 }
104 
105 void transform::AlternativesOp::getSuccessorRegions(
106     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
107   for (Region &alternative : llvm::drop_begin(
108            getAlternatives(),
109            point.isParent() ? 0
110                             : point.getRegionOrNull()->getRegionNumber() + 1)) {
111     regions.emplace_back(&alternative, !getOperands().empty()
112                                            ? alternative.getArguments()
113                                            : Block::BlockArgListType());
114   }
115   if (!point.isParent())
116     regions.emplace_back(getOperation()->getResults());
117 }
118 
119 void transform::AlternativesOp::getRegionInvocationBounds(
120     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
121   (void)operands;
122   // The region corresponding to the first alternative is always executed, the
123   // remaining may or may not be executed.
124   bounds.reserve(getNumRegions());
125   bounds.emplace_back(1, 1);
126   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
127 }
128 
129 static void forwardEmptyOperands(Block *block, transform::TransformState &state,
130                                  transform::TransformResults &results) {
131   for (const auto &res : block->getParentOp()->getOpResults())
132     results.set(res, {});
133 }
134 
135 DiagnosedSilenceableFailure
136 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
137                                  transform::TransformResults &results,
138                                  transform::TransformState &state) {
139   SmallVector<Operation *> originals;
140   if (Value scopeHandle = getScope())
141     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
142   else
143     originals.push_back(state.getTopLevel());
144 
145   for (Operation *original : originals) {
146     if (original->isAncestor(getOperation())) {
147       auto diag = emitDefiniteFailure()
148                   << "scope must not contain the transforms being applied";
149       diag.attachNote(original->getLoc()) << "scope";
150       return diag;
151     }
152     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
153       auto diag = emitDefiniteFailure()
154                   << "only isolated-from-above ops can be alternative scopes";
155       diag.attachNote(original->getLoc()) << "scope";
156       return diag;
157     }
158   }
159 
160   for (Region &reg : getAlternatives()) {
161     // Clone the scope operations and make the transforms in this alternative
162     // region apply to them by virtue of mapping the block argument (the only
163     // visible handle) to the cloned scope operations. This effectively prevents
164     // the transformation from accessing any IR outside the scope.
165     auto scope = state.make_region_scope(reg);
166     auto clones = llvm::to_vector(
167         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
168     auto deleteClones = llvm::make_scope_exit([&] {
169       for (Operation *clone : clones)
170         clone->erase();
171     });
172     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
173       return DiagnosedSilenceableFailure::definiteFailure();
174 
175     bool failed = false;
176     for (Operation &transform : reg.front().without_terminator()) {
177       DiagnosedSilenceableFailure result =
178           state.applyTransform(cast<TransformOpInterface>(transform));
179       if (result.isSilenceableFailure()) {
180         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
181                           << "\n");
182         failed = true;
183         break;
184       }
185 
186       if (::mlir::failed(result.silence()))
187         return DiagnosedSilenceableFailure::definiteFailure();
188     }
189 
190     // If all operations in the given alternative succeeded, no need to consider
191     // the rest. Replace the original scoping operation with the clone on which
192     // the transformations were performed.
193     if (!failed) {
194       // We will be using the clones, so cancel their scheduled deletion.
195       deleteClones.release();
196       TrackingListener listener(state, *this);
197       IRRewriter rewriter(getContext(), &listener);
198       for (const auto &kvp : llvm::zip(originals, clones)) {
199         Operation *original = std::get<0>(kvp);
200         Operation *clone = std::get<1>(kvp);
201         original->getBlock()->getOperations().insert(original->getIterator(),
202                                                      clone);
203         rewriter.replaceOp(original, clone->getResults());
204       }
205       detail::forwardTerminatorOperands(&reg.front(), state, results);
206       return DiagnosedSilenceableFailure::success();
207     }
208   }
209   return emitSilenceableError() << "all alternatives failed";
210 }
211 
212 void transform::AlternativesOp::getEffects(
213     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
214   consumesHandle(getOperation()->getOpOperands(), effects);
215   producesHandle(getOperation()->getOpResults(), effects);
216   for (Region *region : getRegions()) {
217     if (!region->empty())
218       producesHandle(region->front().getArguments(), effects);
219   }
220   modifiesPayload(effects);
221 }
222 
223 LogicalResult transform::AlternativesOp::verify() {
224   for (Region &alternative : getAlternatives()) {
225     Block &block = alternative.front();
226     Operation *terminator = block.getTerminator();
227     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
228       InFlightDiagnostic diag = emitOpError()
229                                 << "expects terminator operands to have the "
230                                    "same type as results of the operation";
231       diag.attachNote(terminator->getLoc()) << "terminator";
232       return diag;
233     }
234   }
235 
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // AnnotateOp
241 //===----------------------------------------------------------------------===//
242 
243 DiagnosedSilenceableFailure
244 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
245                              transform::TransformResults &results,
246                              transform::TransformState &state) {
247   SmallVector<Operation *> targets =
248       llvm::to_vector(state.getPayloadOps(getTarget()));
249 
250   Attribute attr = UnitAttr::get(getContext());
251   if (auto paramH = getParam()) {
252     ArrayRef<Attribute> params = state.getParams(paramH);
253     if (params.size() != 1) {
254       if (targets.size() != params.size()) {
255         return emitSilenceableError()
256                << "parameter and target have different payload lengths ("
257                << params.size() << " vs " << targets.size() << ")";
258       }
259       for (auto &&[target, attr] : llvm::zip_equal(targets, params))
260         target->setAttr(getName(), attr);
261       return DiagnosedSilenceableFailure::success();
262     }
263     attr = params[0];
264   }
265   for (auto *target : targets)
266     target->setAttr(getName(), attr);
267   return DiagnosedSilenceableFailure::success();
268 }
269 
270 void transform::AnnotateOp::getEffects(
271     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
272   onlyReadsHandle(getTargetMutable(), effects);
273   onlyReadsHandle(getParamMutable(), effects);
274   modifiesPayload(effects);
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // ApplyCommonSubexpressionEliminationOp
279 //===----------------------------------------------------------------------===//
280 
281 DiagnosedSilenceableFailure
282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
283     transform::TransformRewriter &rewriter, Operation *target,
284     ApplyToEachResultList &results, transform::TransformState &state) {
285   // Make sure that this transform is not applied to itself. Modifying the
286   // transform IR while it is being interpreted is generally dangerous.
287   DiagnosedSilenceableFailure payloadCheck =
288       ensurePayloadIsSeparateFromTransform(*this, target);
289   if (!payloadCheck.succeeded())
290     return payloadCheck;
291 
292   DominanceInfo domInfo;
293   mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
294   return DiagnosedSilenceableFailure::success();
295 }
296 
297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
298     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
299   transform::onlyReadsHandle(getTargetMutable(), effects);
300   transform::modifiesPayload(effects);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // ApplyDeadCodeEliminationOp
305 //===----------------------------------------------------------------------===//
306 
307 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
308     transform::TransformRewriter &rewriter, Operation *target,
309     ApplyToEachResultList &results, transform::TransformState &state) {
310   // Make sure that this transform is not applied to itself. Modifying the
311   // transform IR while it is being interpreted is generally dangerous.
312   DiagnosedSilenceableFailure payloadCheck =
313       ensurePayloadIsSeparateFromTransform(*this, target);
314   if (!payloadCheck.succeeded())
315     return payloadCheck;
316 
317   // Maintain a worklist of potentially dead ops.
318   SetVector<Operation *> worklist;
319 
320   // Helper function that adds all defining ops of used values (operands and
321   // operands of nested ops).
322   auto addDefiningOpsToWorklist = [&](Operation *op) {
323     op->walk([&](Operation *op) {
324       for (Value v : op->getOperands())
325         if (Operation *defOp = v.getDefiningOp())
326           if (target->isProperAncestor(defOp))
327             worklist.insert(defOp);
328     });
329   };
330 
331   // Helper function that erases an op.
332   auto eraseOp = [&](Operation *op) {
333     // Remove op and nested ops from the worklist.
334     op->walk([&](Operation *op) {
335       const auto *it = llvm::find(worklist, op);
336       if (it != worklist.end())
337         worklist.erase(it);
338     });
339     rewriter.eraseOp(op);
340   };
341 
342   // Initial walk over the IR.
343   target->walk<WalkOrder::PostOrder>([&](Operation *op) {
344     if (op != target && isOpTriviallyDead(op)) {
345       addDefiningOpsToWorklist(op);
346       eraseOp(op);
347     }
348   });
349 
350   // Erase all ops that have become dead.
351   while (!worklist.empty()) {
352     Operation *op = worklist.pop_back_val();
353     if (!isOpTriviallyDead(op))
354       continue;
355     addDefiningOpsToWorklist(op);
356     eraseOp(op);
357   }
358 
359   return DiagnosedSilenceableFailure::success();
360 }
361 
362 void transform::ApplyDeadCodeEliminationOp::getEffects(
363     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
364   transform::onlyReadsHandle(getTargetMutable(), effects);
365   transform::modifiesPayload(effects);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ApplyPatternsOp
370 //===----------------------------------------------------------------------===//
371 
372 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
373     transform::TransformRewriter &rewriter, Operation *target,
374     ApplyToEachResultList &results, transform::TransformState &state) {
375   // Make sure that this transform is not applied to itself. Modifying the
376   // transform IR while it is being interpreted is generally dangerous. Even
377   // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
378   // performs many additional simplifications such as dead code elimination.
379   DiagnosedSilenceableFailure payloadCheck =
380       ensurePayloadIsSeparateFromTransform(*this, target);
381   if (!payloadCheck.succeeded())
382     return payloadCheck;
383 
384   // Gather all specified patterns.
385   MLIRContext *ctx = target->getContext();
386   RewritePatternSet patterns(ctx);
387   if (!getRegion().empty()) {
388     for (Operation &op : getRegion().front()) {
389       cast<transform::PatternDescriptorOpInterface>(&op)
390           .populatePatternsWithState(patterns, state);
391     }
392   }
393 
394   // Configure the GreedyPatternRewriteDriver.
395   GreedyRewriteConfig config;
396   config.listener =
397       static_cast<RewriterBase::Listener *>(rewriter.getListener());
398   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
399 
400   config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
401                              ? GreedyRewriteConfig::kNoLimit
402                              : getMaxIterations();
403   config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
404                               ? GreedyRewriteConfig::kNoLimit
405                               : getMaxNumRewrites();
406 
407   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
408   // was requested, apply the greedy pattern rewrite only once. (The greedy
409   // pattern rewrite driver already iterates to a fixpoint internally.)
410   bool cseChanged = false;
411   // One or two iterations should be sufficient. Stop iterating after a certain
412   // threshold to make debugging easier.
413   static const int64_t kNumMaxIterations = 50;
414   int64_t iteration = 0;
415   do {
416     LogicalResult result = failure();
417     if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
418       // Op is isolated from above. Apply patterns and also perform region
419       // simplification.
420       result = applyPatternsGreedily(target, frozenPatterns, config);
421     } else {
422       // Manually gather list of ops because the other
423       // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
424       // from above. This way, patterns can be applied to ops that are not
425       // isolated from above. Regions are not being simplified. Furthermore,
426       // only a single greedy rewrite iteration is performed.
427       SmallVector<Operation *> ops;
428       target->walk([&](Operation *nestedOp) {
429         if (target != nestedOp)
430           ops.push_back(nestedOp);
431       });
432       result = applyOpPatternsGreedily(ops, frozenPatterns, config);
433     }
434 
435     // A failure typically indicates that the pattern application did not
436     // converge.
437     if (failed(result)) {
438       return emitSilenceableFailure(target)
439              << "greedy pattern application failed";
440     }
441 
442     if (getApplyCse()) {
443       DominanceInfo domInfo;
444       mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
445                                           &cseChanged);
446     }
447   } while (cseChanged && ++iteration < kNumMaxIterations);
448 
449   if (iteration == kNumMaxIterations)
450     return emitDefiniteFailure() << "fixpoint iteration did not converge";
451 
452   return DiagnosedSilenceableFailure::success();
453 }
454 
455 LogicalResult transform::ApplyPatternsOp::verify() {
456   if (!getRegion().empty()) {
457     for (Operation &op : getRegion().front()) {
458       if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
459         InFlightDiagnostic diag = emitOpError()
460                                   << "expected children ops to implement "
461                                      "PatternDescriptorOpInterface";
462         diag.attachNote(op.getLoc()) << "op without interface";
463         return diag;
464       }
465     }
466   }
467   return success();
468 }
469 
470 void transform::ApplyPatternsOp::getEffects(
471     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
472   transform::onlyReadsHandle(getTargetMutable(), effects);
473   transform::modifiesPayload(effects);
474 }
475 
476 void transform::ApplyPatternsOp::build(
477     OpBuilder &builder, OperationState &result, Value target,
478     function_ref<void(OpBuilder &, Location)> bodyBuilder) {
479   result.addOperands(target);
480 
481   OpBuilder::InsertionGuard g(builder);
482   Region *region = result.addRegion();
483   builder.createBlock(region);
484   if (bodyBuilder)
485     bodyBuilder(builder, result.location);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // ApplyCanonicalizationPatternsOp
490 //===----------------------------------------------------------------------===//
491 
492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
493     RewritePatternSet &patterns) {
494   MLIRContext *ctx = patterns.getContext();
495   for (Dialect *dialect : ctx->getLoadedDialects())
496     dialect->getCanonicalizationPatterns(patterns);
497   for (RegisteredOperationName op : ctx->getRegisteredOperations())
498     op.getCanonicalizationPatterns(patterns, ctx);
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // ApplyConversionPatternsOp
503 //===----------------------------------------------------------------------===//
504 
505 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
506     transform::TransformRewriter &rewriter,
507     transform::TransformResults &results, transform::TransformState &state) {
508   MLIRContext *ctx = getContext();
509 
510   // Instantiate the default type converter if a type converter builder is
511   // specified.
512   std::unique_ptr<TypeConverter> defaultTypeConverter;
513   transform::TypeConverterBuilderOpInterface typeConverterBuilder =
514       getDefaultTypeConverter();
515   if (typeConverterBuilder)
516     defaultTypeConverter = typeConverterBuilder.getTypeConverter();
517 
518   // Configure conversion target.
519   ConversionTarget conversionTarget(*getContext());
520   if (getLegalOps())
521     for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522       conversionTarget.addLegalOp(
523           OperationName(cast<StringAttr>(attr).getValue(), ctx));
524   if (getIllegalOps())
525     for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526       conversionTarget.addIllegalOp(
527           OperationName(cast<StringAttr>(attr).getValue(), ctx));
528   if (getLegalDialects())
529     for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530       conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531   if (getIllegalDialects())
532     for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533       conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
534 
535   // Gather all specified patterns.
536   RewritePatternSet patterns(ctx);
537   // Need to keep the converters alive until after pattern application because
538   // the patterns take a reference to an object that would otherwise get out of
539   // scope.
540   SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
541   if (!getPatterns().empty()) {
542     for (Operation &op : getPatterns().front()) {
543       auto descriptor =
544           cast<transform::ConversionPatternDescriptorOpInterface>(&op);
545 
546       // Check if this pattern set specifies a type converter.
547       std::unique_ptr<TypeConverter> typeConverter =
548           descriptor.getTypeConverter();
549       TypeConverter *converter = nullptr;
550       if (typeConverter) {
551         keepAliveConverters.emplace_back(std::move(typeConverter));
552         converter = keepAliveConverters.back().get();
553       } else {
554         // No type converter specified: Use the default type converter.
555         if (!defaultTypeConverter) {
556           auto diag = emitDefiniteFailure()
557                       << "pattern descriptor does not specify type "
558                          "converter and apply_conversion_patterns op has "
559                          "no default type converter";
560           diag.attachNote(op.getLoc()) << "pattern descriptor op";
561           return diag;
562         }
563         converter = defaultTypeConverter.get();
564       }
565 
566       // Add descriptor-specific updates to the conversion target, which may
567       // depend on the final type converter. In structural converters, the
568       // legality of types dictates the dynamic legality of an operation.
569       descriptor.populateConversionTargetRules(*converter, conversionTarget);
570 
571       descriptor.populatePatterns(*converter, patterns);
572     }
573   }
574 
575   // Attach a tracking listener if handles should be preserved. We configure the
576   // listener to allow op replacements with different names, as conversion
577   // patterns typically replace ops with replacement ops that have a different
578   // name.
579   TrackingListenerConfig trackingConfig;
580   trackingConfig.requireMatchingReplacementOpName = false;
581   ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
582   ConversionConfig conversionConfig;
583   if (getPreserveHandles())
584     conversionConfig.listener = &trackingListener;
585 
586   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
587   for (Operation *target : state.getPayloadOps(getTarget())) {
588     // Make sure that this transform is not applied to itself. Modifying the
589     // transform IR while it is being interpreted is generally dangerous.
590     DiagnosedSilenceableFailure payloadCheck =
591         ensurePayloadIsSeparateFromTransform(*this, target);
592     if (!payloadCheck.succeeded())
593       return payloadCheck;
594 
595     LogicalResult status = failure();
596     if (getPartialConversion()) {
597       status = applyPartialConversion(target, conversionTarget, frozenPatterns,
598                                       conversionConfig);
599     } else {
600       status = applyFullConversion(target, conversionTarget, frozenPatterns,
601                                    conversionConfig);
602     }
603 
604     // Check dialect conversion state.
605     DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
606     if (failed(status)) {
607       diag = emitSilenceableError() << "dialect conversion failed";
608       diag.attachNote(target->getLoc()) << "target op";
609     }
610 
611     // Check tracking listener error state.
612     DiagnosedSilenceableFailure trackingFailure =
613         trackingListener.checkAndResetError();
614     if (!trackingFailure.succeeded()) {
615       if (diag.succeeded()) {
616         // Tracking failure is the only failure.
617         return trackingFailure;
618       } else {
619         diag.attachNote() << "tracking listener also failed: "
620                           << trackingFailure.getMessage();
621         (void)trackingFailure.silence();
622       }
623     }
624 
625     if (!diag.succeeded())
626       return diag;
627   }
628 
629   return DiagnosedSilenceableFailure::success();
630 }
631 
632 LogicalResult transform::ApplyConversionPatternsOp::verify() {
633   if (getNumRegions() != 1 && getNumRegions() != 2)
634     return emitOpError() << "expected 1 or 2 regions";
635   if (!getPatterns().empty()) {
636     for (Operation &op : getPatterns().front()) {
637       if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
638         InFlightDiagnostic diag =
639             emitOpError() << "expected pattern children ops to implement "
640                              "ConversionPatternDescriptorOpInterface";
641         diag.attachNote(op.getLoc()) << "op without interface";
642         return diag;
643       }
644     }
645   }
646   if (getNumRegions() == 2) {
647     Region &typeConverterRegion = getRegion(1);
648     if (!llvm::hasSingleElement(typeConverterRegion.front()))
649       return emitOpError()
650              << "expected exactly one op in default type converter region";
651     Operation *maybeTypeConverter = &typeConverterRegion.front().front();
652     auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
653         maybeTypeConverter);
654     if (!typeConverterOp) {
655       InFlightDiagnostic diag = emitOpError()
656                                 << "expected default converter child op to "
657                                    "implement TypeConverterBuilderOpInterface";
658       diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
659       return diag;
660     }
661     // Check default type converter type.
662     if (!getPatterns().empty()) {
663       for (Operation &op : getPatterns().front()) {
664         auto descriptor =
665             cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666         if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
667           return failure();
668       }
669     }
670   }
671   return success();
672 }
673 
674 void transform::ApplyConversionPatternsOp::getEffects(
675     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676   if (!getPreserveHandles()) {
677     transform::consumesHandle(getTargetMutable(), effects);
678   } else {
679     transform::onlyReadsHandle(getTargetMutable(), effects);
680   }
681   transform::modifiesPayload(effects);
682 }
683 
684 void transform::ApplyConversionPatternsOp::build(
685     OpBuilder &builder, OperationState &result, Value target,
686     function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
687     function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
688   result.addOperands(target);
689 
690   {
691     OpBuilder::InsertionGuard g(builder);
692     Region *region1 = result.addRegion();
693     builder.createBlock(region1);
694     if (patternsBodyBuilder)
695       patternsBodyBuilder(builder, result.location);
696   }
697   {
698     OpBuilder::InsertionGuard g(builder);
699     Region *region2 = result.addRegion();
700     builder.createBlock(region2);
701     if (typeConverterBodyBuilder)
702       typeConverterBodyBuilder(builder, result.location);
703   }
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // ApplyToLLVMConversionPatternsOp
708 //===----------------------------------------------------------------------===//
709 
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
711     TypeConverter &typeConverter, RewritePatternSet &patterns) {
712   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
713   assert(dialect && "expected that dialect is loaded");
714   auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
715   // ConversionTarget is currently ignored because the enclosing
716   // apply_conversion_patterns op sets up its own ConversionTarget.
717   ConversionTarget target(*getContext());
718   iface->populateConvertToLLVMConversionPatterns(
719       target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
720 }
721 
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723     transform::TypeConverterBuilderOpInterface builder) {
724   if (builder.getTypeConverterType() != "LLVMTypeConverter")
725     return emitOpError("expected LLVMTypeConverter");
726   return success();
727 }
728 
729 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
730   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
731   if (!dialect)
732     return emitOpError("unknown dialect or dialect not loaded: ")
733            << getDialectName();
734   auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
735   if (!iface)
736     return emitOpError(
737                "dialect does not implement ConvertToLLVMPatternInterface or "
738                "extension was not loaded: ")
739            << getDialectName();
740   return success();
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // ApplyLoopInvariantCodeMotionOp
745 //===----------------------------------------------------------------------===//
746 
747 DiagnosedSilenceableFailure
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
749     transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
750     transform::ApplyToEachResultList &results,
751     transform::TransformState &state) {
752   // Currently, LICM does not remove operations, so we don't need tracking.
753   // If this ever changes, add a LICM entry point that takes a rewriter.
754   moveLoopInvariantCode(target);
755   return DiagnosedSilenceableFailure::success();
756 }
757 
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
760   transform::onlyReadsHandle(getTargetMutable(), effects);
761   transform::modifiesPayload(effects);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // ApplyRegisteredPassOp
766 //===----------------------------------------------------------------------===//
767 
768 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
769     transform::TransformRewriter &rewriter, Operation *target,
770     ApplyToEachResultList &results, transform::TransformState &state) {
771   // Make sure that this transform is not applied to itself. Modifying the
772   // transform IR while it is being interpreted is generally dangerous. Even
773   // more so when applying passes because they may perform a wide range of IR
774   // modifications.
775   DiagnosedSilenceableFailure payloadCheck =
776       ensurePayloadIsSeparateFromTransform(*this, target);
777   if (!payloadCheck.succeeded())
778     return payloadCheck;
779 
780   // Get pass or pass pipeline from registry.
781   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
782   if (!info)
783     info = PassInfo::lookup(getPassName());
784   if (!info)
785     return emitDefiniteFailure()
786            << "unknown pass or pass pipeline: " << getPassName();
787 
788   // Create pass manager and run the pass or pass pipeline.
789   PassManager pm(getContext());
790   if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
791         emitError(msg);
792         return failure();
793       }))) {
794     return emitDefiniteFailure()
795            << "failed to add pass or pass pipeline to pipeline: "
796            << getPassName();
797   }
798   if (failed(pm.run(target))) {
799     auto diag = emitSilenceableError() << "pass pipeline failed";
800     diag.attachNote(target->getLoc()) << "target op";
801     return diag;
802   }
803 
804   results.push_back(target);
805   return DiagnosedSilenceableFailure::success();
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // CastOp
810 //===----------------------------------------------------------------------===//
811 
812 DiagnosedSilenceableFailure
813 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
814                               Operation *target, ApplyToEachResultList &results,
815                               transform::TransformState &state) {
816   results.push_back(target);
817   return DiagnosedSilenceableFailure::success();
818 }
819 
820 void transform::CastOp::getEffects(
821     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
822   onlyReadsPayload(effects);
823   onlyReadsHandle(getInputMutable(), effects);
824   producesHandle(getOperation()->getOpResults(), effects);
825 }
826 
827 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
828   assert(inputs.size() == 1 && "expected one input");
829   assert(outputs.size() == 1 && "expected one output");
830   return llvm::all_of(
831       std::initializer_list<Type>{inputs.front(), outputs.front()},
832       llvm::IsaPred<transform::TransformHandleTypeInterface>);
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // CollectMatchingOp
837 //===----------------------------------------------------------------------===//
838 
839 /// Applies matcher operations from the given `block` using
840 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
841 /// accordingly. If any of the matcher produces a silenceable failure, discards
842 /// it (printing the content to the debug output stream) and returns failure. If
843 /// any of the matchers produces a definite failure, reports it and returns
844 /// failure. If all matchers in the block succeed, populates `mappings` with the
845 /// payload entities associated with the block terminator operands. Note that
846 /// `mappings` will be cleared before that.
847 static DiagnosedSilenceableFailure
848 matchBlock(Block &block,
849            ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
850            transform::TransformState &state,
851            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
852   assert(block.getParent() && "cannot match using a detached block");
853   auto matchScope = state.make_region_scope(*block.getParent());
854   if (failed(
855           state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
856     return DiagnosedSilenceableFailure::definiteFailure();
857 
858   for (Operation &match : block.without_terminator()) {
859     if (!isa<transform::MatchOpInterface>(match)) {
860       return emitDefiniteFailure(match.getLoc())
861              << "expected operations in the match part to "
862                 "implement MatchOpInterface";
863     }
864     DiagnosedSilenceableFailure diag =
865         state.applyTransform(cast<transform::TransformOpInterface>(match));
866     if (diag.succeeded())
867       continue;
868 
869     return diag;
870   }
871 
872   // Remember the values mapped to the terminator operands so we can
873   // forward them to the action.
874   ValueRange yieldedValues = block.getTerminator()->getOperands();
875   // Our contract with the caller is that the mappings will contain only the
876   // newly mapped values, clear the rest.
877   mappings.clear();
878   transform::detail::prepareValueMappings(mappings, yieldedValues, state);
879   return DiagnosedSilenceableFailure::success();
880 }
881 
882 /// Returns `true` if both types implement one of the interfaces provided as
883 /// template parameters.
884 template <typename... Tys>
885 static bool implementSameInterface(Type t1, Type t2) {
886   return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
887 }
888 
889 /// Returns `true` if both types implement one of the transform dialect
890 /// interfaces.
891 static bool implementSameTransformInterface(Type t1, Type t2) {
892   return implementSameInterface<transform::TransformHandleTypeInterface,
893                                 transform::TransformParamTypeInterface,
894                                 transform::TransformValueHandleTypeInterface>(
895       t1, t2);
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // CollectMatchingOp
900 //===----------------------------------------------------------------------===//
901 
902 DiagnosedSilenceableFailure
903 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
904                                     transform::TransformResults &results,
905                                     transform::TransformState &state) {
906   auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907       getOperation(), getMatcher());
908   if (matcher.isExternal()) {
909     return emitDefiniteFailure()
910            << "unresolved external symbol " << getMatcher();
911   }
912 
913   SmallVector<SmallVector<MappedValue>, 2> rawResults;
914   rawResults.resize(getOperation()->getNumResults());
915   std::optional<DiagnosedSilenceableFailure> maybeFailure;
916   for (Operation *root : state.getPayloadOps(getRoot())) {
917     WalkResult walkResult = root->walk([&](Operation *op) {
918       DEBUG_MATCHER({
919         DBGS_MATCHER() << "matching ";
920         op->print(llvm::dbgs(),
921                   OpPrintingFlags().assumeVerified().skipRegions());
922         llvm::dbgs() << " @" << op << "\n";
923       });
924 
925       // Try matching.
926       SmallVector<SmallVector<MappedValue>> mappings;
927       SmallVector<transform::MappedValue> inputMapping({op});
928       DiagnosedSilenceableFailure diag = matchBlock(
929           matcher.getFunctionBody().front(),
930           ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
931           mappings);
932       if (diag.isDefiniteFailure())
933         return WalkResult::interrupt();
934       if (diag.isSilenceableFailure()) {
935         DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
936                                      << " failed: " << diag.getMessage());
937         return WalkResult::advance();
938       }
939 
940       // If succeeded, collect results.
941       for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
942         if (mapping.size() != 1) {
943           maybeFailure.emplace(emitSilenceableError()
944                                << "result #" << i << ", associated with "
945                                << mapping.size()
946                                << " payload objects, expected 1");
947           return WalkResult::interrupt();
948         }
949         rawResults[i].push_back(mapping[0]);
950       }
951       return WalkResult::advance();
952     });
953     if (walkResult.wasInterrupted())
954       return std::move(*maybeFailure);
955     assert(!maybeFailure && "failure set but the walk was not interrupted");
956 
957     for (auto &&[opResult, rawResult] :
958          llvm::zip_equal(getOperation()->getResults(), rawResults)) {
959       results.setMappedValues(opResult, rawResult);
960     }
961   }
962   return DiagnosedSilenceableFailure::success();
963 }
964 
965 void transform::CollectMatchingOp::getEffects(
966     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
967   onlyReadsHandle(getRootMutable(), effects);
968   producesHandle(getOperation()->getOpResults(), effects);
969   onlyReadsPayload(effects);
970 }
971 
972 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
973     SymbolTableCollection &symbolTable) {
974   auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
975       symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
976   if (!matcherSymbol ||
977       !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978     return emitError() << "unresolved matcher symbol " << getMatcher();
979 
980   ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
981   if (argumentTypes.size() != 1 ||
982       !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
983     return emitError()
984            << "expected the matcher to take one operation handle argument";
985   }
986   if (!matcherSymbol.getArgAttr(
987           0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988     return emitError() << "expected the matcher argument to be marked readonly";
989   }
990 
991   ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
992   if (resultTypes.size() != getOperation()->getNumResults()) {
993     return emitError()
994            << "expected the matcher to yield as many values as op has results ("
995            << getOperation()->getNumResults() << "), got "
996            << resultTypes.size();
997   }
998 
999   for (auto &&[i, matcherType, resultType] :
1000        llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1001     if (implementSameTransformInterface(matcherType, resultType))
1002       continue;
1003 
1004     return emitError()
1005            << "mismatching type interfaces for matcher result and op result #"
1006            << i;
1007   }
1008 
1009   return success();
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // ForeachMatchOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 // This is fine because nothing is actually consumed by this op.
1017 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1018 
/// Walks the payload nested under each op associated with the `root` handle
/// and, for every visited op, tries the matcher/action pairs in declaration
/// order. For the first matcher that succeeds, its yielded values are bound to
/// the paired action's block arguments and the action's transforms are
/// applied; the remaining pairs are not tried for that op. Values yielded by
/// the action terminators are appended, across all matches, to the forwarded
/// output results.
///
/// Failure handling:
///   - a silenceable matcher failure means "no match" and the next pair is
///     tried;
///   - a silenceable action failure is recorded as a note on a single
///     "actions failed" silenceable error (returned after the walk), and the
///     remaining transforms of that action still run;
///   - any definite failure (matcher, action, argument mapping, or results
///     needing flattening when it was not requested) aborts the walk and a
///     definite failure is returned.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve all matcher/action symbols up front; external (body-less)
  // declarations cannot be interpreted.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Accumulates notes for silenceable action failures; stays success if no
  // action fails silenceably.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();

  // matchInputMapping: matcher block arguments (visited op + forwarded inputs).
  // matchOutputMapping: values yielded by the last successful matcher.
  // actionResultMapping: values yielded by actions, one list per forwarded
  // output, accumulated over all matches.
  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  // Slot for the currently visited op; rewritten for every op in the walk.
  // `matchInputMapping` is not resized after this point, so the reference
  // stays valid.
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        // Bind the matcher's yielded values to the action's block arguments
        // for the duration of this action.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          if (result.isSilenceableFailure()) {
            // Lazily create the aggregate error on the first failure, then
            // attach one note pair per failing action application.
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Append (rather than overwrite) so that results from all matches are
        // forwarded.
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first successful matcher's action applies to this op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}
1145 
1146 void transform::ForeachMatchOp::getAsmResultNames(
1147     OpAsmSetValueNameFn setNameFn) {
1148   setNameFn(getUpdated(), "updated_root");
1149   for (Value v : getForwardedOutputs()) {
1150     setNameFn(v, "yielded");
1151   }
1152 }
1153 
1154 void transform::ForeachMatchOp::getEffects(
1155     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1156   // Bail if invalid.
1157   if (getOperation()->getNumOperands() < 1 ||
1158       getOperation()->getNumResults() < 1) {
1159     return modifiesPayload(effects);
1160   }
1161 
1162   consumesHandle(getRootMutable(), effects);
1163   onlyReadsHandle(getForwardedInputsMutable(), effects);
1164   producesHandle(getOperation()->getOpResults(), effects);
1165   modifiesPayload(effects);
1166 }
1167 
1168 /// Parses the comma-separated list of symbol reference pairs of the format
1169 /// `@matcher -> @action`.
1170 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1171                                             ArrayAttr &matchers,
1172                                             ArrayAttr &actions) {
1173   StringAttr matcher;
1174   StringAttr action;
1175   SmallVector<Attribute> matcherList;
1176   SmallVector<Attribute> actionList;
1177   do {
1178     if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1179         parser.parseSymbolName(action)) {
1180       return failure();
1181     }
1182     matcherList.push_back(SymbolRefAttr::get(matcher));
1183     actionList.push_back(SymbolRefAttr::get(action));
1184   } while (parser.parseOptionalComma().succeeded());
1185 
1186   matchers = parser.getBuilder().getArrayAttr(matcherList);
1187   actions = parser.getBuilder().getArrayAttr(actionList);
1188   return success();
1189 }
1190 
1191 /// Prints the comma-separated list of symbol reference pairs of the format
1192 /// `@matcher -> @action`.
1193 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
1194                                      ArrayAttr matchers, ArrayAttr actions) {
1195   printer.increaseIndent();
1196   printer.increaseIndent();
1197   for (auto &&[matcher, action, idx] : llvm::zip_equal(
1198            matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1199     printer.printNewline();
1200     printer << cast<SymbolRefAttr>(matcher) << " -> "
1201             << cast<SymbolRefAttr>(action);
1202     if (idx != matchers.size() - 1)
1203       printer << ", ";
1204   }
1205   printer.decreaseIndent();
1206   printer.decreaseIndent();
1207 }
1208 
1209 LogicalResult transform::ForeachMatchOp::verify() {
1210   if (getMatchers().size() != getActions().size())
1211     return emitOpError() << "expected the same number of matchers and actions";
1212   if (getMatchers().empty())
1213     return emitOpError() << "expected at least one match/action pair";
1214 
1215   llvm::SmallPtrSet<Attribute, 8> matcherNames;
1216   for (Attribute name : getMatchers()) {
1217     if (matcherNames.insert(name).second)
1218       continue;
1219     emitWarning() << "matcher " << name
1220                   << " is used more than once, only the first match will apply";
1221   }
1222 
1223   return success();
1224 }
1225 
1226 /// Checks that the attributes of the function-like operation have correct
1227 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1228 /// annotations being present even if they can be inferred from the body.
1229 static DiagnosedSilenceableFailure
1230 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1231                                      bool alsoVerifyInternal = false) {
1232   auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1233   llvm::SmallDenseSet<unsigned> consumedArguments;
1234   if (!op.isExternal()) {
1235     transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1236                                          consumedArguments);
1237   }
1238   for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1239     bool isConsumed =
1240         op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1241         nullptr;
1242     bool isReadOnly =
1243         op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1244         nullptr;
1245     if (isConsumed && isReadOnly) {
1246       return transformOp.emitSilenceableError()
1247              << "argument #" << i << " cannot be both readonly and consumed";
1248     }
1249     if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1250       return transformOp.emitSilenceableError()
1251              << "must provide consumed/readonly status for arguments of "
1252                 "external or called ops";
1253     }
1254     if (op.isExternal())
1255       continue;
1256 
1257     if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1258       return transformOp.emitSilenceableError()
1259              << "argument #" << i
1260              << " is consumed in the body but is not marked as such";
1261     }
1262     if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1263       // Cannot use op.emitWarning() here as it would attempt to verify the op
1264       // before printing, resulting in infinite recursion.
1265       emitWarning(op->getLoc())
1266           << "op argument #" << i
1267           << " is not consumed in the body but is marked as consumed";
1268     }
1269   }
1270   return DiagnosedSilenceableFailure::success();
1271 }
1272 
1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1274     SymbolTableCollection &symbolTable) {
1275   assert(getMatchers().size() == getActions().size());
1276   auto consumedAttr =
1277       StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1278   for (auto &&[matcher, action] :
1279        llvm::zip_equal(getMatchers(), getActions())) {
1280     // Presence and typing.
1281     auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1282         symbolTable.lookupNearestSymbolFrom(getOperation(),
1283                                             cast<SymbolRefAttr>(matcher)));
1284     auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1285         symbolTable.lookupNearestSymbolFrom(getOperation(),
1286                                             cast<SymbolRefAttr>(action)));
1287     if (!matcherSymbol ||
1288         !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1289       return emitError() << "unresolved matcher symbol " << matcher;
1290     if (!actionSymbol ||
1291         !isa<TransformOpInterface>(actionSymbol.getOperation()))
1292       return emitError() << "unresolved action symbol " << action;
1293 
1294     if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1295                                                     /*emitWarnings=*/false,
1296                                                     /*alsoVerifyInternal=*/true)
1297                    .checkAndReport())) {
1298       return failure();
1299     }
1300     if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1301                                                     /*emitWarnings=*/false,
1302                                                     /*alsoVerifyInternal=*/true)
1303                    .checkAndReport())) {
1304       return failure();
1305     }
1306 
1307     // Input -> matcher forwarding.
1308     TypeRange operandTypes = getOperandTypes();
1309     TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310     if (operandTypes.size() != matcherArguments.size()) {
1311       InFlightDiagnostic diag =
1312           emitError() << "the number of operands (" << operandTypes.size()
1313                       << ") doesn't match the number of matcher arguments ("
1314                       << matcherArguments.size() << ") for " << matcher;
1315       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1316       return diag;
1317     }
1318     for (auto &&[i, operand, argument] :
1319          llvm::enumerate(operandTypes, matcherArguments)) {
1320       if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1321         InFlightDiagnostic diag =
1322             emitOpError()
1323             << "does not expect matcher symbol to consume its operand #" << i;
1324         diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1325         return diag;
1326       }
1327 
1328       if (implementSameTransformInterface(operand, argument))
1329         continue;
1330 
1331       InFlightDiagnostic diag =
1332           emitError()
1333           << "mismatching type interfaces for operand and matcher argument #"
1334           << i << " of matcher " << matcher;
1335       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1336       return diag;
1337     }
1338 
1339     // Matcher -> action forwarding.
1340     TypeRange matcherResults = matcherSymbol.getResultTypes();
1341     TypeRange actionArguments = actionSymbol.getArgumentTypes();
1342     if (matcherResults.size() != actionArguments.size()) {
1343       return emitError() << "mismatching number of matcher results and "
1344                             "action arguments between "
1345                          << matcher << " (" << matcherResults.size() << ") and "
1346                          << action << " (" << actionArguments.size() << ")";
1347     }
1348     for (auto &&[i, matcherType, actionType] :
1349          llvm::enumerate(matcherResults, actionArguments)) {
1350       if (implementSameTransformInterface(matcherType, actionType))
1351         continue;
1352 
1353       return emitError() << "mismatching type interfaces for matcher result "
1354                             "and action argument #"
1355                          << i << "of matcher " << matcher << " and action "
1356                          << action;
1357     }
1358 
1359     // Action -> result forwarding.
1360     TypeRange actionResults = actionSymbol.getResultTypes();
1361     auto resultTypes = TypeRange(getResultTypes()).drop_front();
1362     if (actionResults.size() != resultTypes.size()) {
1363       InFlightDiagnostic diag =
1364           emitError() << "the number of action results ("
1365                       << actionResults.size() << ") for " << action
1366                       << " doesn't match the number of extra op results ("
1367                       << resultTypes.size() << ")";
1368       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1369       return diag;
1370     }
1371     for (auto &&[i, resultType, actionType] :
1372          llvm::enumerate(resultTypes, actionResults)) {
1373       if (implementSameTransformInterface(resultType, actionType))
1374         continue;
1375 
1376       InFlightDiagnostic diag =
1377           emitError() << "mismatching type interfaces for action result #" << i
1378                       << " of action " << action << " and op result";
1379       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1380       return diag;
1381     }
1382   }
1383   return success();
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // ForeachOp
1388 //===----------------------------------------------------------------------===//
1389 
1390 DiagnosedSilenceableFailure
1391 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1392                             transform::TransformResults &results,
1393                             transform::TransformState &state) {
1394   // We store the payloads before executing the body as ops may be removed from
1395   // the mapping by the TrackingRewriter while iteration is in progress.
1396   SmallVector<SmallVector<MappedValue>> payloads;
1397   detail::prepareValueMappings(payloads, getTargets(), state);
1398   size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399   bool withZipShortest = getWithZipShortest();
1400 
1401   // In case of `zip_shortest`, set the number of iterations to the
1402   // smallest payload in the targets.
1403   if (withZipShortest) {
1404     numIterations =
1405         llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1406                                         const SmallVector<MappedValue> &B) {
1407           return A.size() < B.size();
1408         })->size();
1409 
1410     for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411       payloads[argIdx].resize(numIterations);
1412   }
1413 
1414   // As we will be "zipping" over them, check all payloads have the same size.
1415   // `zip_shortest` adjusts all payloads to the same size, so skip this check
1416   // when true.
1417   for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1418        argIdx++) {
1419     if (payloads[argIdx].size() != numIterations) {
1420       return emitSilenceableError()
1421              << "prior targets' payload size (" << numIterations
1422              << ") differs from payload size (" << payloads[argIdx].size()
1423              << ") of target " << getTargets()[argIdx];
1424     }
1425   }
1426 
1427   // Start iterating, indexing into payloads to obtain the right arguments to
1428   // call the body with - each slice of payloads at the same argument index
1429   // corresponding to a tuple to use as the body's block arguments.
1430   ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1431   SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1432   for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433     auto scope = state.make_region_scope(getBody());
1434     // Set up arguments to the region's block.
1435     for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1436       MappedValue argument = payloads[argIdx][iterIdx];
1437       // Note that each blockArg's handle gets associated with just a single
1438       // element from the corresponding target's payload.
1439       if (failed(state.mapBlockArgument(blockArg, {argument})))
1440         return DiagnosedSilenceableFailure::definiteFailure();
1441     }
1442 
1443     // Execute loop body.
1444     for (Operation &transform : getBody().front().without_terminator()) {
1445       DiagnosedSilenceableFailure result = state.applyTransform(
1446           llvm::cast<transform::TransformOpInterface>(transform));
1447       if (!result.succeeded())
1448         return result;
1449     }
1450 
1451     // Append yielded payloads to corresponding results from prior iterations.
1452     OperandRange yieldOperands = getYieldOp().getOperands();
1453     for (auto &&[result, yieldOperand, resTuple] :
1454          llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1455       // NB: each iteration we add any number of ops/vals/params to a result.
1456       if (isa<TransformHandleTypeInterface>(result.getType()))
1457         llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458       else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459         llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460       else if (isa<TransformParamTypeInterface>(result.getType()))
1461         llvm::append_range(resTuple, state.getParams(yieldOperand));
1462       else
1463         assert(false && "unhandled handle type");
1464   }
1465 
1466   // Associate the accumulated result payloads to the op's actual results.
1467   for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1468     results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1469 
1470   return DiagnosedSilenceableFailure::success();
1471 }
1472 
1473 void transform::ForeachOp::getEffects(
1474     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1475   // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1476   // arity errors, this method might get called before/in absence of `verify()`.
1477   for (auto &&[target, blockArg] :
1478        llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1479     BlockArgument blockArgument = blockArg;
1480     if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1481           return isHandleConsumed(blockArgument,
1482                                   cast<TransformOpInterface>(&op));
1483         })) {
1484       consumesHandle(target, effects);
1485     } else {
1486       onlyReadsHandle(target, effects);
1487     }
1488   }
1489 
1490   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1491         return doesModifyPayload(cast<TransformOpInterface>(&op));
1492       })) {
1493     modifiesPayload(effects);
1494   } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1495                return doesReadPayload(cast<TransformOpInterface>(&op));
1496              })) {
1497     onlyReadsPayload(effects);
1498   }
1499 
1500   producesHandle(getOperation()->getOpResults(), effects);
1501 }
1502 
1503 void transform::ForeachOp::getSuccessorRegions(
1504     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1505   Region *bodyRegion = &getBody();
1506   if (point.isParent()) {
1507     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1508     return;
1509   }
1510 
1511   // Branch back to the region or the parent.
1512   assert(point == getBody() && "unexpected region index");
1513   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1514   regions.emplace_back();
1515 }
1516 
1517 OperandRange
1518 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1519   // Each block argument handle is mapped to a subset (one op to be precise)
1520   // of the payload of the corresponding `targets` operand of ForeachOp.
1521   assert(point == getBody() && "unexpected region index");
1522   return getOperation()->getOperands();
1523 }
1524 
1525 transform::YieldOp transform::ForeachOp::getYieldOp() {
1526   return cast<transform::YieldOp>(getBody().front().getTerminator());
1527 }
1528 
1529 LogicalResult transform::ForeachOp::verify() {
1530   for (auto [targetOpt, bodyArgOpt] :
1531        llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532     if (!targetOpt || !bodyArgOpt)
1533       return emitOpError() << "expects the same number of targets as the body "
1534                               "has block arguments";
1535     if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1536       return emitOpError(
1537           "expects co-indexed targets and the body's "
1538           "block arguments to have the same op/value/param type");
1539   }
1540 
1541   for (auto [resultOpt, yieldOperandOpt] :
1542        llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543     if (!resultOpt || !yieldOperandOpt)
1544       return emitOpError() << "expects the same number of results as the "
1545                               "yield terminator has operands";
1546     if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547       return emitOpError("expects co-indexed results and yield "
1548                          "operands to have the same op/value/param type");
1549   }
1550 
1551   return success();
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // GetParentOp
1556 //===----------------------------------------------------------------------===//
1557 
1558 DiagnosedSilenceableFailure
1559 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1560                               transform::TransformResults &results,
1561                               transform::TransformState &state) {
1562   SmallVector<Operation *> parents;
1563   DenseSet<Operation *> resultSet;
1564   for (Operation *target : state.getPayloadOps(getTarget())) {
1565     Operation *parent = target;
1566     for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1567       parent = parent->getParentOp();
1568       while (parent) {
1569         bool checkIsolatedFromAbove =
1570             !getIsolatedFromAbove() ||
1571             parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1572         bool checkOpName = !getOpName().has_value() ||
1573                            parent->getName().getStringRef() == *getOpName();
1574         if (checkIsolatedFromAbove && checkOpName)
1575           break;
1576         parent = parent->getParentOp();
1577       }
1578       if (!parent) {
1579         if (getAllowEmptyResults()) {
1580           results.set(llvm::cast<OpResult>(getResult()), parents);
1581           return DiagnosedSilenceableFailure::success();
1582         }
1583         DiagnosedSilenceableFailure diag =
1584             emitSilenceableError()
1585             << "could not find a parent op that matches all requirements";
1586         diag.attachNote(target->getLoc()) << "target op";
1587         return diag;
1588       }
1589     }
1590     if (getDeduplicate()) {
1591       if (resultSet.insert(parent).second)
1592         parents.push_back(parent);
1593     } else {
1594       parents.push_back(parent);
1595     }
1596   }
1597   results.set(llvm::cast<OpResult>(getResult()), parents);
1598   return DiagnosedSilenceableFailure::success();
1599 }
1600 
1601 //===----------------------------------------------------------------------===//
1602 // GetConsumersOfResult
1603 //===----------------------------------------------------------------------===//
1604 
1605 DiagnosedSilenceableFailure
1606 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1607                                        transform::TransformResults &results,
1608                                        transform::TransformState &state) {
1609   int64_t resultNumber = getResultNumber();
1610   auto payloadOps = state.getPayloadOps(getTarget());
1611   if (std::empty(payloadOps)) {
1612     results.set(cast<OpResult>(getResult()), {});
1613     return DiagnosedSilenceableFailure::success();
1614   }
1615   if (!llvm::hasSingleElement(payloadOps))
1616     return emitDefiniteFailure()
1617            << "handle must be mapped to exactly one payload op";
1618 
1619   Operation *target = *payloadOps.begin();
1620   if (target->getNumResults() <= resultNumber)
1621     return emitDefiniteFailure() << "result number overflow";
1622   results.set(llvm::cast<OpResult>(getResult()),
1623               llvm::to_vector(target->getResult(resultNumber).getUsers()));
1624   return DiagnosedSilenceableFailure::success();
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // GetDefiningOp
1629 //===----------------------------------------------------------------------===//
1630 
1631 DiagnosedSilenceableFailure
1632 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1633                                 transform::TransformResults &results,
1634                                 transform::TransformState &state) {
1635   SmallVector<Operation *> definingOps;
1636   for (Value v : state.getPayloadValues(getTarget())) {
1637     if (llvm::isa<BlockArgument>(v)) {
1638       DiagnosedSilenceableFailure diag =
1639           emitSilenceableError() << "cannot get defining op of block argument";
1640       diag.attachNote(v.getLoc()) << "target value";
1641       return diag;
1642     }
1643     definingOps.push_back(v.getDefiningOp());
1644   }
1645   results.set(llvm::cast<OpResult>(getResult()), definingOps);
1646   return DiagnosedSilenceableFailure::success();
1647 }
1648 
1649 //===----------------------------------------------------------------------===//
1650 // GetProducerOfOperand
1651 //===----------------------------------------------------------------------===//
1652 
1653 DiagnosedSilenceableFailure
1654 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1655                                        transform::TransformResults &results,
1656                                        transform::TransformState &state) {
1657   int64_t operandNumber = getOperandNumber();
1658   SmallVector<Operation *> producers;
1659   for (Operation *target : state.getPayloadOps(getTarget())) {
1660     Operation *producer =
1661         target->getNumOperands() <= operandNumber
1662             ? nullptr
1663             : target->getOperand(operandNumber).getDefiningOp();
1664     if (!producer) {
1665       DiagnosedSilenceableFailure diag =
1666           emitSilenceableError()
1667           << "could not find a producer for operand number: " << operandNumber
1668           << " of " << *target;
1669       diag.attachNote(target->getLoc()) << "target op";
1670       return diag;
1671     }
1672     producers.push_back(producer);
1673   }
1674   results.set(llvm::cast<OpResult>(getResult()), producers);
1675   return DiagnosedSilenceableFailure::success();
1676 }
1677 
1678 //===----------------------------------------------------------------------===//
1679 // GetOperandOp
1680 //===----------------------------------------------------------------------===//
1681 
1682 DiagnosedSilenceableFailure
1683 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1684                                transform::TransformResults &results,
1685                                transform::TransformState &state) {
1686   SmallVector<Value> operands;
1687   for (Operation *target : state.getPayloadOps(getTarget())) {
1688     SmallVector<int64_t> operandPositions;
1689     DiagnosedSilenceableFailure diag = expandTargetSpecification(
1690         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1691         target->getNumOperands(), operandPositions);
1692     if (diag.isSilenceableFailure()) {
1693       diag.attachNote(target->getLoc())
1694           << "while considering positions of this payload operation";
1695       return diag;
1696     }
1697     llvm::append_range(operands,
1698                        llvm::map_range(operandPositions, [&](int64_t pos) {
1699                          return target->getOperand(pos);
1700                        }));
1701   }
1702   results.setValues(cast<OpResult>(getResult()), operands);
1703   return DiagnosedSilenceableFailure::success();
1704 }
1705 
1706 LogicalResult transform::GetOperandOp::verify() {
1707   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1708                                     getIsInverted(), getIsAll());
1709 }
1710 
1711 //===----------------------------------------------------------------------===//
1712 // GetResultOp
1713 //===----------------------------------------------------------------------===//
1714 
1715 DiagnosedSilenceableFailure
1716 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1717                               transform::TransformResults &results,
1718                               transform::TransformState &state) {
1719   SmallVector<Value> opResults;
1720   for (Operation *target : state.getPayloadOps(getTarget())) {
1721     SmallVector<int64_t> resultPositions;
1722     DiagnosedSilenceableFailure diag = expandTargetSpecification(
1723         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1724         target->getNumResults(), resultPositions);
1725     if (diag.isSilenceableFailure()) {
1726       diag.attachNote(target->getLoc())
1727           << "while considering positions of this payload operation";
1728       return diag;
1729     }
1730     llvm::append_range(opResults,
1731                        llvm::map_range(resultPositions, [&](int64_t pos) {
1732                          return target->getResult(pos);
1733                        }));
1734   }
1735   results.setValues(cast<OpResult>(getResult()), opResults);
1736   return DiagnosedSilenceableFailure::success();
1737 }
1738 
1739 LogicalResult transform::GetResultOp::verify() {
1740   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1741                                     getIsInverted(), getIsAll());
1742 }
1743 
1744 //===----------------------------------------------------------------------===//
1745 // GetTypeOp
1746 //===----------------------------------------------------------------------===//
1747 
1748 void transform::GetTypeOp::getEffects(
1749     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1750   onlyReadsHandle(getValueMutable(), effects);
1751   producesHandle(getOperation()->getOpResults(), effects);
1752   onlyReadsPayload(effects);
1753 }
1754 
1755 DiagnosedSilenceableFailure
1756 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1757                             transform::TransformResults &results,
1758                             transform::TransformState &state) {
1759   SmallVector<Attribute> params;
1760   for (Value value : state.getPayloadValues(getValue())) {
1761     Type type = value.getType();
1762     if (getElemental()) {
1763       if (auto shaped = dyn_cast<ShapedType>(type)) {
1764         type = shaped.getElementType();
1765       }
1766     }
1767     params.push_back(TypeAttr::get(type));
1768   }
1769   results.setParams(cast<OpResult>(getResult()), params);
1770   return DiagnosedSilenceableFailure::success();
1771 }
1772 
1773 //===----------------------------------------------------------------------===//
1774 // IncludeOp
1775 //===----------------------------------------------------------------------===//
1776 
1777 /// Applies the transform ops contained in `block`. Maps `results` to the same
1778 /// values as the operands of the block terminator.
1779 static DiagnosedSilenceableFailure
1780 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1781                    transform::TransformState &state,
1782                    transform::TransformResults &results) {
1783   // Apply the sequenced ops one by one.
1784   for (Operation &transform : block.without_terminator()) {
1785     DiagnosedSilenceableFailure result =
1786         state.applyTransform(cast<transform::TransformOpInterface>(transform));
1787     if (result.isDefiniteFailure())
1788       return result;
1789 
1790     if (result.isSilenceableFailure()) {
1791       if (mode == transform::FailurePropagationMode::Propagate) {
1792         // Propagate empty results in case of early exit.
1793         forwardEmptyOperands(&block, state, results);
1794         return result;
1795       }
1796       (void)result.silence();
1797     }
1798   }
1799 
1800   // Forward the operation mapping for values yielded from the sequence to the
1801   // values produced by the sequence op.
1802   transform::detail::forwardTerminatorOperands(&block, state, results);
1803   return DiagnosedSilenceableFailure::success();
1804 }
1805 
1806 DiagnosedSilenceableFailure
1807 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1808                             transform::TransformResults &results,
1809                             transform::TransformState &state) {
1810   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1811       getOperation(), getTarget());
1812   assert(callee && "unverified reference to unknown symbol");
1813 
1814   if (callee.isExternal())
1815     return emitDefiniteFailure() << "unresolved external named sequence";
1816 
1817   // Map operands to block arguments.
1818   SmallVector<SmallVector<MappedValue>> mappings;
1819   detail::prepareValueMappings(mappings, getOperands(), state);
1820   auto scope = state.make_region_scope(callee.getBody());
1821   for (auto &&[arg, map] :
1822        llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1823     if (failed(state.mapBlockArgument(arg, map)))
1824       return DiagnosedSilenceableFailure::definiteFailure();
1825   }
1826 
1827   DiagnosedSilenceableFailure result = applySequenceBlock(
1828       callee.getBody().front(), getFailurePropagationMode(), state, results);
1829   mappings.clear();
1830   detail::prepareValueMappings(
1831       mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1832   for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1833     results.setMappedValues(result, mapping);
1834   return result;
1835 }
1836 
/// Forward declaration of the structural verifier for named sequences. It is
/// `static`, so it is defined elsewhere in this translation unit. It is used
/// by `IncludeOp::getEffects` below to check the callee before trusting its
/// argument attributes. `emitWarnings` presumably toggles reporting of
/// non-fatal issues as warnings; confirm at the definition.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1839 
1840 void transform::IncludeOp::getEffects(
1841     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1842   // Always mark as modifying the payload.
1843   // TODO: a mechanism to annotate effects on payload. Even when all handles are
1844   // only read, the payload may still be modified, so we currently stay on the
1845   // conservative side and always indicate modification. This may prevent some
1846   // code reordering.
1847   modifiesPayload(effects);
1848 
1849   // Results are always produced.
1850   producesHandle(getOperation()->getOpResults(), effects);
1851 
1852   // Adds default effects to operands and results. This will be added if
1853   // preconditions fail so the trait verifier doesn't complain about missing
1854   // effects and the real precondition failure is reported later on.
1855   auto defaultEffects = [&] {
1856     onlyReadsHandle(getOperation()->getOpOperands(), effects);
1857   };
1858 
1859   // Bail if the callee is unknown. This may run as part of the verification
1860   // process before we verified the validity of the callee or of this op.
1861   auto target =
1862       getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1863   if (!target)
1864     return defaultEffects();
1865   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1866       getOperation(), getTarget());
1867   if (!callee)
1868     return defaultEffects();
1869   DiagnosedSilenceableFailure earlyVerifierResult =
1870       verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1871   if (!earlyVerifierResult.succeeded()) {
1872     (void)earlyVerifierResult.silence();
1873     return defaultEffects();
1874   }
1875 
1876   for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1877     if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1878       consumesHandle(getOperation()->getOpOperand(i), effects);
1879     else
1880       onlyReadsHandle(getOperation()->getOpOperand(i), effects);
1881   }
1882 }
1883 
1884 LogicalResult
1885 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1886   // Access through indirection and do additional checking because this may be
1887   // running before the main op verifier.
1888   auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1889   if (!targetAttr)
1890     return emitOpError() << "expects a 'target' symbol reference attribute";
1891 
1892   auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1893       *this, targetAttr);
1894   if (!target)
1895     return emitOpError() << "does not reference a named transform sequence";
1896 
1897   FunctionType fnType = target.getFunctionType();
1898   if (fnType.getNumInputs() != getNumOperands())
1899     return emitError("incorrect number of operands for callee");
1900 
1901   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1902     if (getOperand(i).getType() != fnType.getInput(i)) {
1903       return emitOpError("operand type mismatch: expected operand type ")
1904              << fnType.getInput(i) << ", but provided "
1905              << getOperand(i).getType() << " for operand number " << i;
1906     }
1907   }
1908 
1909   if (fnType.getNumResults() != getNumResults())
1910     return emitError("incorrect number of results for callee");
1911 
1912   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1913     Type resultType = getResult(i).getType();
1914     Type funcType = fnType.getResult(i);
1915     if (!implementSameTransformInterface(resultType, funcType)) {
1916       return emitOpError() << "type of result #" << i
1917                            << " must implement the same transform dialect "
1918                               "interface as the corresponding callee result";
1919     }
1920   }
1921 
1922   return verifyFunctionLikeConsumeAnnotations(
1923              cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1924              /*alsoVerifyInternal=*/true)
1925       .checkAndReport();
1926 }
1927 
1928 //===----------------------------------------------------------------------===//
1929 // MatchOperationEmptyOp
1930 //===----------------------------------------------------------------------===//
1931 
1932 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1933     ::std::optional<::mlir::Operation *> maybeCurrent,
1934     transform::TransformResults &results, transform::TransformState &state) {
1935   if (!maybeCurrent.has_value()) {
1936     DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1937     return DiagnosedSilenceableFailure::success();
1938   }
1939   DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1940   return emitSilenceableError() << "operation is not empty";
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // MatchOperationNameOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1948     Operation *current, transform::TransformResults &results,
1949     transform::TransformState &state) {
1950   StringRef currentOpName = current->getName().getStringRef();
1951   for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1952     if (acceptedAttr.getValue() == currentOpName)
1953       return DiagnosedSilenceableFailure::success();
1954   }
1955   return emitSilenceableError() << "wrong operation name";
1956 }
1957 
1958 //===----------------------------------------------------------------------===//
1959 // MatchParamCmpIOp
1960 //===----------------------------------------------------------------------===//
1961 
1962 DiagnosedSilenceableFailure
1963 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1964                                    transform::TransformResults &results,
1965                                    transform::TransformState &state) {
1966   auto signedAPIntAsString = [&](const APInt &value) {
1967     std::string str;
1968     llvm::raw_string_ostream os(str);
1969     value.print(os, /*isSigned=*/true);
1970     return str;
1971   };
1972 
1973   ArrayRef<Attribute> params = state.getParams(getParam());
1974   ArrayRef<Attribute> references = state.getParams(getReference());
1975 
1976   if (params.size() != references.size()) {
1977     return emitSilenceableError()
1978            << "parameters have different payload lengths (" << params.size()
1979            << " vs " << references.size() << ")";
1980   }
1981 
1982   for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1983     auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1984     auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1985     if (!intAttr || !refAttr) {
1986       return emitDefiniteFailure()
1987              << "non-integer parameter value not expected";
1988     }
1989     if (intAttr.getType() != refAttr.getType()) {
1990       return emitDefiniteFailure()
1991              << "mismatching integer attribute types in parameter #" << i;
1992     }
1993     APInt value = intAttr.getValue();
1994     APInt refValue = refAttr.getValue();
1995 
1996     // TODO: this copy will not be necessary in C++20.
1997     int64_t position = i;
1998     auto reportError = [&](StringRef direction) {
1999       DiagnosedSilenceableFailure diag =
2000           emitSilenceableError() << "expected parameter to be " << direction
2001                                  << " " << signedAPIntAsString(refValue)
2002                                  << ", got " << signedAPIntAsString(value);
2003       diag.attachNote(getParam().getLoc())
2004           << "value # " << position
2005           << " associated with the parameter defined here";
2006       return diag;
2007     };
2008 
2009     switch (getPredicate()) {
2010     case MatchCmpIPredicate::eq:
2011       if (value.eq(refValue))
2012         break;
2013       return reportError("equal to");
2014     case MatchCmpIPredicate::ne:
2015       if (value.ne(refValue))
2016         break;
2017       return reportError("not equal to");
2018     case MatchCmpIPredicate::lt:
2019       if (value.slt(refValue))
2020         break;
2021       return reportError("less than");
2022     case MatchCmpIPredicate::le:
2023       if (value.sle(refValue))
2024         break;
2025       return reportError("less than or equal to");
2026     case MatchCmpIPredicate::gt:
2027       if (value.sgt(refValue))
2028         break;
2029       return reportError("greater than");
2030     case MatchCmpIPredicate::ge:
2031       if (value.sge(refValue))
2032         break;
2033       return reportError("greater than or equal to");
2034     }
2035   }
2036   return DiagnosedSilenceableFailure::success();
2037 }
2038 
2039 void transform::MatchParamCmpIOp::getEffects(
2040     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2041   onlyReadsHandle(getParamMutable(), effects);
2042   onlyReadsHandle(getReferenceMutable(), effects);
2043 }
2044 
2045 //===----------------------------------------------------------------------===//
2046 // ParamConstantOp
2047 //===----------------------------------------------------------------------===//
2048 
2049 DiagnosedSilenceableFailure
2050 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2051                                   transform::TransformResults &results,
2052                                   transform::TransformState &state) {
2053   results.setParams(cast<OpResult>(getParam()), {getValue()});
2054   return DiagnosedSilenceableFailure::success();
2055 }
2056 
2057 //===----------------------------------------------------------------------===//
2058 // MergeHandlesOp
2059 //===----------------------------------------------------------------------===//
2060 
2061 DiagnosedSilenceableFailure
2062 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2063                                  transform::TransformResults &results,
2064                                  transform::TransformState &state) {
2065   ValueRange handles = getHandles();
2066   if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2067     SmallVector<Operation *> operations;
2068     for (Value operand : handles)
2069       llvm::append_range(operations, state.getPayloadOps(operand));
2070     if (!getDeduplicate()) {
2071       results.set(llvm::cast<OpResult>(getResult()), operations);
2072       return DiagnosedSilenceableFailure::success();
2073     }
2074 
2075     SetVector<Operation *> uniqued(operations.begin(), operations.end());
2076     results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2077     return DiagnosedSilenceableFailure::success();
2078   }
2079 
2080   if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2081     SmallVector<Attribute> attrs;
2082     for (Value attribute : handles)
2083       llvm::append_range(attrs, state.getParams(attribute));
2084     if (!getDeduplicate()) {
2085       results.setParams(cast<OpResult>(getResult()), attrs);
2086       return DiagnosedSilenceableFailure::success();
2087     }
2088 
2089     SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
2090     results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2091     return DiagnosedSilenceableFailure::success();
2092   }
2093 
2094   assert(
2095       llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2096       "expected value handle type");
2097   SmallVector<Value> payloadValues;
2098   for (Value value : handles)
2099     llvm::append_range(payloadValues, state.getPayloadValues(value));
2100   if (!getDeduplicate()) {
2101     results.setValues(cast<OpResult>(getResult()), payloadValues);
2102     return DiagnosedSilenceableFailure::success();
2103   }
2104 
2105   SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
2106   results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2107   return DiagnosedSilenceableFailure::success();
2108 }
2109 
2110 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2111   // Handles may be the same if deduplicating is enabled.
2112   return getDeduplicate();
2113 }
2114 
2115 void transform::MergeHandlesOp::getEffects(
2116     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2117   onlyReadsHandle(getHandlesMutable(), effects);
2118   producesHandle(getOperation()->getOpResults(), effects);
2119 
2120   // There are no effects on the Payload IR as this is only a handle
2121   // manipulation.
2122 }
2123 
2124 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2125   if (getDeduplicate() || getHandles().size() != 1)
2126     return {};
2127 
2128   // If deduplication is not required and there is only one operand, it can be
2129   // used directly instead of merging.
2130   return getHandles().front();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // NamedSequenceOp
2135 //===----------------------------------------------------------------------===//
2136 
2137 DiagnosedSilenceableFailure
2138 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2139                                   transform::TransformResults &results,
2140                                   transform::TransformState &state) {
2141   if (isExternal())
2142     return emitDefiniteFailure() << "unresolved external named sequence";
2143 
2144   // Map the entry block argument to the list of operations.
2145   // Note: this is the same implementation as PossibleTopLevelTransformOp but
2146   // without attaching the interface / trait since that is tailored to a
2147   // dangling top-level op that does not get "called".
2148   auto scope = state.make_region_scope(getBody());
2149   if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2150           state, this->getOperation(), getBody())))
2151     return DiagnosedSilenceableFailure::definiteFailure();
2152 
2153   return applySequenceBlock(getBody().front(),
2154                             FailurePropagationMode::Propagate, state, results);
2155 }
2156 
2157 void transform::NamedSequenceOp::getEffects(
2158     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2159 
2160 ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2161                                               OperationState &result) {
2162   return function_interface_impl::parseFunctionOp(
2163       parser, result, /*allowVariadic=*/false,
2164       getFunctionTypeAttrName(result.name),
2165       [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2166          function_interface_impl::VariadicFlag,
2167          std::string &) { return builder.getFunctionType(inputs, results); },
2168       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2169 }
2170 
2171 void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2172   function_interface_impl::printFunctionOp(
2173       printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2174       getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2175       getResAttrsAttrName());
2176 }
2177 
2178 /// Verifies that a symbol function-like transform dialect operation has the
2179 /// signature and the terminator that have conforming types, i.e., types
2180 /// implementing the same transform dialect type interface. If `allowExternal`
2181 /// is set, allow external symbols (declarations) and don't check the terminator
2182 /// as it may not exist.
2183 static DiagnosedSilenceableFailure
2184 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2185   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2186     DiagnosedSilenceableFailure diag =
2187         emitSilenceableFailure(op)
2188         << "cannot be defined inside another transform op";
2189     diag.attachNote(parent.getLoc()) << "ancestor transform op";
2190     return diag;
2191   }
2192 
2193   if (op.isExternal() || op.getFunctionBody().empty()) {
2194     if (allowExternal)
2195       return DiagnosedSilenceableFailure::success();
2196 
2197     return emitSilenceableFailure(op) << "cannot be external";
2198   }
2199 
2200   if (op.getFunctionBody().front().empty())
2201     return emitSilenceableFailure(op) << "expected a non-empty body block";
2202 
2203   Operation *terminator = &op.getFunctionBody().front().back();
2204   if (!isa<transform::YieldOp>(terminator)) {
2205     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2206                                        << "expected '"
2207                                        << transform::YieldOp::getOperationName()
2208                                        << "' as terminator";
2209     diag.attachNote(terminator->getLoc()) << "terminator";
2210     return diag;
2211   }
2212 
2213   if (terminator->getNumOperands() != op.getResultTypes().size()) {
2214     return emitSilenceableFailure(terminator)
2215            << "expected terminator to have as many operands as the parent op "
2216               "has results";
2217   }
2218   for (auto [i, operandType, resultType] : llvm::zip_equal(
2219            llvm::seq<unsigned>(0, terminator->getNumOperands()),
2220            terminator->getOperands().getType(), op.getResultTypes())) {
2221     if (operandType == resultType)
2222       continue;
2223     return emitSilenceableFailure(terminator)
2224            << "the type of the terminator operand #" << i
2225            << " must match the type of the corresponding parent op result ("
2226            << operandType << " vs " << resultType << ")";
2227   }
2228 
2229   return DiagnosedSilenceableFailure::success();
2230 }
2231 
2232 /// Verification of a NamedSequenceOp. This does not report the error
2233 /// immediately, so it can be used to check for op's well-formedness before the
2234 /// verifier runs, e.g., during trait verification.
2235 static DiagnosedSilenceableFailure
2236 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2237   if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2238     if (!parent->getAttr(
2239             transform::TransformDialect::kWithNamedSequenceAttrName)) {
2240       DiagnosedSilenceableFailure diag =
2241           emitSilenceableFailure(op)
2242           << "expects the parent symbol table to have the '"
2243           << transform::TransformDialect::kWithNamedSequenceAttrName
2244           << "' attribute";
2245       diag.attachNote(parent->getLoc()) << "symbol table operation";
2246       return diag;
2247     }
2248   }
2249 
2250   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2251     DiagnosedSilenceableFailure diag =
2252         emitSilenceableFailure(op)
2253         << "cannot be defined inside another transform op";
2254     diag.attachNote(parent.getLoc()) << "ancestor transform op";
2255     return diag;
2256   }
2257 
2258   if (op.isExternal() || op.getBody().empty())
2259     return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2260                                                 emitWarnings);
2261 
2262   if (op.getBody().front().empty())
2263     return emitSilenceableFailure(op) << "expected a non-empty body block";
2264 
2265   Operation *terminator = &op.getBody().front().back();
2266   if (!isa<transform::YieldOp>(terminator)) {
2267     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2268                                        << "expected '"
2269                                        << transform::YieldOp::getOperationName()
2270                                        << "' as terminator";
2271     diag.attachNote(terminator->getLoc()) << "terminator";
2272     return diag;
2273   }
2274 
2275   if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2276     return emitSilenceableFailure(terminator)
2277            << "expected terminator to have as many operands as the parent op "
2278               "has results";
2279   }
2280   for (auto [i, operandType, resultType] :
2281        llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2282                        terminator->getOperands().getType(),
2283                        op.getFunctionType().getResults())) {
2284     if (operandType == resultType)
2285       continue;
2286     return emitSilenceableFailure(terminator)
2287            << "the type of the terminator operand #" << i
2288            << " must match the type of the corresponding parent op result ("
2289            << operandType << " vs " << resultType << ")";
2290   }
2291 
2292   auto funcOp = cast<FunctionOpInterface>(*op);
2293   DiagnosedSilenceableFailure diag =
2294       verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2295   if (!diag.succeeded())
2296     return diag;
2297 
2298   return verifyYieldingSingleBlockOp(funcOp,
2299                                      /*allowExternal=*/true);
2300 }
2301 
2302 LogicalResult transform::NamedSequenceOp::verify() {
2303   // Actual verification happens in a separate function for reusability.
2304   return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2305 }
2306 
2307 template <typename FnTy>
2308 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2309                               Type bbArgType, TypeRange extraBindingTypes,
2310                               FnTy bodyBuilder) {
2311   SmallVector<Type> types;
2312   types.reserve(1 + extraBindingTypes.size());
2313   types.push_back(bbArgType);
2314   llvm::append_range(types, extraBindingTypes);
2315 
2316   OpBuilder::InsertionGuard guard(builder);
2317   Region *region = state.regions.back().get();
2318   Block *bodyBlock =
2319       builder.createBlock(region, region->begin(), types,
2320                           SmallVector<Location>(types.size(), state.location));
2321 
2322   // Populate body.
2323   builder.setInsertionPointToStart(bodyBlock);
2324   if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2325     bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2326   } else {
2327     bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2328                 bodyBlock->getArguments().drop_front());
2329   }
2330 }
2331 
2332 void transform::NamedSequenceOp::build(OpBuilder &builder,
2333                                        OperationState &state, StringRef symName,
2334                                        Type rootType, TypeRange resultTypes,
2335                                        SequenceBodyBuilderFn bodyBuilder,
2336                                        ArrayRef<NamedAttribute> attrs,
2337                                        ArrayRef<DictionaryAttr> argAttrs) {
2338   state.addAttribute(SymbolTable::getSymbolAttrName(),
2339                      builder.getStringAttr(symName));
2340   state.addAttribute(getFunctionTypeAttrName(state.name),
2341                      TypeAttr::get(FunctionType::get(builder.getContext(),
2342                                                      rootType, resultTypes)));
2343   state.attributes.append(attrs.begin(), attrs.end());
2344   state.addRegion();
2345 
2346   buildSequenceBody(builder, state, rootType,
2347                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2348 }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // NumAssociationsOp
2352 //===----------------------------------------------------------------------===//
2353 
2354 DiagnosedSilenceableFailure
2355 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2356                                     transform::TransformResults &results,
2357                                     transform::TransformState &state) {
2358   size_t numAssociations =
2359       llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2360           .Case([&](TransformHandleTypeInterface opHandle) {
2361             return llvm::range_size(state.getPayloadOps(getHandle()));
2362           })
2363           .Case([&](TransformValueHandleTypeInterface valueHandle) {
2364             return llvm::range_size(state.getPayloadValues(getHandle()));
2365           })
2366           .Case([&](TransformParamTypeInterface param) {
2367             return llvm::range_size(state.getParams(getHandle()));
2368           })
2369           .Default([](Type) {
2370             llvm_unreachable("unknown kind of transform dialect type");
2371             return 0;
2372           });
2373   results.setParams(cast<OpResult>(getNum()),
2374                     rewriter.getI64IntegerAttr(numAssociations));
2375   return DiagnosedSilenceableFailure::success();
2376 }
2377 
2378 LogicalResult transform::NumAssociationsOp::verify() {
2379   // Verify that the result type accepts an i64 attribute as payload.
2380   auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2381   return resultType
2382       .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2383       .checkAndReport();
2384 }
2385 
2386 //===----------------------------------------------------------------------===//
2387 // SelectOp
2388 //===----------------------------------------------------------------------===//
2389 
2390 DiagnosedSilenceableFailure
2391 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2392                            transform::TransformResults &results,
2393                            transform::TransformState &state) {
2394   SmallVector<Operation *> result;
2395   auto payloadOps = state.getPayloadOps(getTarget());
2396   for (Operation *op : payloadOps) {
2397     if (op->getName().getStringRef() == getOpName())
2398       result.push_back(op);
2399   }
2400   results.set(cast<OpResult>(getResult()), result);
2401   return DiagnosedSilenceableFailure::success();
2402 }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // SplitHandleOp
2406 //===----------------------------------------------------------------------===//
2407 
2408 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2409                                      Value target, int64_t numResultHandles) {
2410   result.addOperands(target);
2411   result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2412 }
2413 
2414 DiagnosedSilenceableFailure
2415 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2416                                 transform::TransformResults &results,
2417                                 transform::TransformState &state) {
2418   int64_t numPayloads =
2419       llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
2420           .Case<TransformHandleTypeInterface>([&](auto x) {
2421             return llvm::range_size(state.getPayloadOps(getHandle()));
2422           })
2423           .Case<TransformValueHandleTypeInterface>([&](auto x) {
2424             return llvm::range_size(state.getPayloadValues(getHandle()));
2425           })
2426           .Case<TransformParamTypeInterface>([&](auto x) {
2427             return llvm::range_size(state.getParams(getHandle()));
2428           })
2429           .Default([](auto x) {
2430             llvm_unreachable("unknown transform dialect type interface");
2431             return -1;
2432           });
2433 
2434   auto produceNumOpsError = [&]() {
2435     return emitSilenceableError()
2436            << getHandle() << " expected to contain " << this->getNumResults()
2437            << " payloads but it contains " << numPayloads << " payloads";
2438   };
2439 
2440   // Fail if there are more payload ops than results and no overflow result was
2441   // specified.
2442   if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2443     return produceNumOpsError();
2444 
2445   // Fail if there are more results than payload ops. Unless:
2446   // - "fail_on_payload_too_small" is set to "false", or
2447   // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2448   if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2449       (numPayloads != 0 || !getPassThroughEmptyHandle()))
2450     return produceNumOpsError();
2451 
2452   // Distribute payloads.
2453   SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2454   if (getOverflowResult())
2455     resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2456 
2457   auto container = [&]() {
2458     if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2459       return llvm::map_to_vector(
2460           state.getPayloadOps(getHandle()),
2461           [](Operation *op) -> MappedValue { return op; });
2462     }
2463     if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2464       return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2465                                  [](Value v) -> MappedValue { return v; });
2466     }
2467     assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2468            "unsupported kind of transform dialect type");
2469     return llvm::map_to_vector(state.getParams(getHandle()),
2470                                [](Attribute a) -> MappedValue { return a; });
2471   }();
2472 
2473   for (auto &&en : llvm::enumerate(container)) {
2474     int64_t resultNum = en.index();
2475     if (resultNum >= getNumResults())
2476       resultNum = *getOverflowResult();
2477     resultHandles[resultNum].push_back(en.value());
2478   }
2479 
2480   // Set transform op results.
2481   for (auto &&it : llvm::enumerate(resultHandles))
2482     results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2483                             it.value());
2484 
2485   return DiagnosedSilenceableFailure::success();
2486 }
2487 
2488 void transform::SplitHandleOp::getEffects(
2489     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2490   onlyReadsHandle(getHandleMutable(), effects);
2491   producesHandle(getOperation()->getOpResults(), effects);
2492   // There are no effects on the Payload IR as this is only a handle
2493   // manipulation.
2494 }
2495 
2496 LogicalResult transform::SplitHandleOp::verify() {
2497   if (getOverflowResult().has_value() &&
2498       !(*getOverflowResult() < getNumResults()))
2499     return emitOpError("overflow_result is not a valid result index");
2500 
2501   for (Type resultType : getResultTypes()) {
2502     if (implementSameTransformInterface(getHandle().getType(), resultType))
2503       continue;
2504 
2505     return emitOpError("expects result types to implement the same transform "
2506                        "interface as the operand type");
2507   }
2508 
2509   return success();
2510 }
2511 
2512 //===----------------------------------------------------------------------===//
2513 // ReplicateOp
2514 //===----------------------------------------------------------------------===//
2515 
2516 DiagnosedSilenceableFailure
2517 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2518                               transform::TransformResults &results,
2519                               transform::TransformState &state) {
2520   unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2521   for (const auto &en : llvm::enumerate(getHandles())) {
2522     Value handle = en.value();
2523     if (isa<TransformHandleTypeInterface>(handle.getType())) {
2524       SmallVector<Operation *> current =
2525           llvm::to_vector(state.getPayloadOps(handle));
2526       SmallVector<Operation *> payload;
2527       payload.reserve(numRepetitions * current.size());
2528       for (unsigned i = 0; i < numRepetitions; ++i)
2529         llvm::append_range(payload, current);
2530       results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2531     } else {
2532       assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2533              "expected param type");
2534       ArrayRef<Attribute> current = state.getParams(handle);
2535       SmallVector<Attribute> params;
2536       params.reserve(numRepetitions * current.size());
2537       for (unsigned i = 0; i < numRepetitions; ++i)
2538         llvm::append_range(params, current);
2539       results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2540                         params);
2541     }
2542   }
2543   return DiagnosedSilenceableFailure::success();
2544 }
2545 
2546 void transform::ReplicateOp::getEffects(
2547     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2548   onlyReadsHandle(getPatternMutable(), effects);
2549   onlyReadsHandle(getHandlesMutable(), effects);
2550   producesHandle(getOperation()->getOpResults(), effects);
2551 }
2552 
2553 //===----------------------------------------------------------------------===//
2554 // SequenceOp
2555 //===----------------------------------------------------------------------===//
2556 
2557 DiagnosedSilenceableFailure
2558 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2559                              transform::TransformResults &results,
2560                              transform::TransformState &state) {
2561   // Map the entry block argument to the list of operations.
2562   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2563   if (failed(mapBlockArguments(state)))
2564     return DiagnosedSilenceableFailure::definiteFailure();
2565 
2566   return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2567                             results);
2568 }
2569 
2570 static ParseResult parseSequenceOpOperands(
2571     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2572     Type &rootType,
2573     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2574     SmallVectorImpl<Type> &extraBindingTypes) {
2575   OpAsmParser::UnresolvedOperand rootOperand;
2576   OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2577   if (!hasRoot.has_value()) {
2578     root = std::nullopt;
2579     return success();
2580   }
2581   if (failed(hasRoot.value()))
2582     return failure();
2583   root = rootOperand;
2584 
2585   if (succeeded(parser.parseOptionalComma())) {
2586     if (failed(parser.parseOperandList(extraBindings)))
2587       return failure();
2588   }
2589   if (failed(parser.parseColon()))
2590     return failure();
2591 
2592   // The paren is truly optional.
2593   (void)parser.parseOptionalLParen();
2594 
2595   if (failed(parser.parseType(rootType))) {
2596     return failure();
2597   }
2598 
2599   if (!extraBindings.empty()) {
2600     if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2601       return failure();
2602   }
2603 
2604   if (extraBindingTypes.size() != extraBindings.size()) {
2605     return parser.emitError(parser.getNameLoc(),
2606                             "expected types to be provided for all operands");
2607   }
2608 
2609   // The paren is truly optional.
2610   (void)parser.parseOptionalRParen();
2611   return success();
2612 }
2613 
2614 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2615                                     Value root, Type rootType,
2616                                     ValueRange extraBindings,
2617                                     TypeRange extraBindingTypes) {
2618   if (!root)
2619     return;
2620 
2621   printer << root;
2622   bool hasExtras = !extraBindings.empty();
2623   if (hasExtras) {
2624     printer << ", ";
2625     printer.printOperands(extraBindings);
2626   }
2627 
2628   printer << " : ";
2629   if (hasExtras)
2630     printer << "(";
2631 
2632   printer << rootType;
2633   if (hasExtras) {
2634     printer << ", ";
2635     llvm::interleaveComma(extraBindingTypes, printer.getStream());
2636     printer << ")";
2637   }
2638 }
2639 
2640 /// Returns `true` if the given op operand may be consuming the handle value in
2641 /// the Transform IR. That is, if it may have a Free effect on it.
2642 static bool isValueUsePotentialConsumer(OpOperand &use) {
2643   // Conservatively assume the effect being present in absence of the interface.
2644   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2645   if (!iface)
2646     return true;
2647 
2648   return isHandleConsumed(use.get(), iface);
2649 }
2650 
2651 LogicalResult
2652 checkDoubleConsume(Value value,
2653                    function_ref<InFlightDiagnostic()> reportError) {
2654   OpOperand *potentialConsumer = nullptr;
2655   for (OpOperand &use : value.getUses()) {
2656     if (!isValueUsePotentialConsumer(use))
2657       continue;
2658 
2659     if (!potentialConsumer) {
2660       potentialConsumer = &use;
2661       continue;
2662     }
2663 
2664     InFlightDiagnostic diag = reportError()
2665                               << " has more than one potential consumer";
2666     diag.attachNote(potentialConsumer->getOwner()->getLoc())
2667         << "used here as operand #" << potentialConsumer->getOperandNumber();
2668     diag.attachNote(use.getOwner()->getLoc())
2669         << "used here as operand #" << use.getOperandNumber();
2670     return diag;
2671   }
2672 
2673   return success();
2674 }
2675 
2676 LogicalResult transform::SequenceOp::verify() {
2677   assert(getBodyBlock()->getNumArguments() >= 1 &&
2678          "the number of arguments must have been verified to be more than 1 by "
2679          "PossibleTopLevelTransformOpTrait");
2680 
2681   if (!getRoot() && !getExtraBindings().empty()) {
2682     return emitOpError()
2683            << "does not expect extra operands when used as top-level";
2684   }
2685 
2686   // Check if a block argument has more than one consuming use.
2687   for (BlockArgument arg : getBodyBlock()->getArguments()) {
2688     if (failed(checkDoubleConsume(arg, [this, arg]() {
2689           return (emitOpError() << "block argument #" << arg.getArgNumber());
2690         }))) {
2691       return failure();
2692     }
2693   }
2694 
2695   // Check properties of the nested operations they cannot check themselves.
2696   for (Operation &child : *getBodyBlock()) {
2697     if (!isa<TransformOpInterface>(child) &&
2698         &child != &getBodyBlock()->back()) {
2699       InFlightDiagnostic diag =
2700           emitOpError()
2701           << "expected children ops to implement TransformOpInterface";
2702       diag.attachNote(child.getLoc()) << "op without interface";
2703       return diag;
2704     }
2705 
2706     for (OpResult result : child.getResults()) {
2707       auto report = [&]() {
2708         return (child.emitError() << "result #" << result.getResultNumber());
2709       };
2710       if (failed(checkDoubleConsume(result, report)))
2711         return failure();
2712     }
2713   }
2714 
2715   if (!getBodyBlock()->mightHaveTerminator())
2716     return emitOpError() << "expects to have a terminator in the body";
2717 
2718   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2719       getOperation()->getResultTypes()) {
2720     InFlightDiagnostic diag = emitOpError()
2721                               << "expects the types of the terminator operands "
2722                                  "to match the types of the result";
2723     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2724     return diag;
2725   }
2726   return success();
2727 }
2728 
2729 void transform::SequenceOp::getEffects(
2730     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2731   getPotentialTopLevelEffects(effects);
2732 }
2733 
2734 OperandRange
2735 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2736   assert(point == getBody() && "unexpected region index");
2737   if (getOperation()->getNumOperands() > 0)
2738     return getOperation()->getOperands();
2739   return OperandRange(getOperation()->operand_end(),
2740                       getOperation()->operand_end());
2741 }
2742 
2743 void transform::SequenceOp::getSuccessorRegions(
2744     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2745   if (point.isParent()) {
2746     Region *bodyRegion = &getBody();
2747     regions.emplace_back(bodyRegion, getNumOperands() != 0
2748                                          ? bodyRegion->getArguments()
2749                                          : Block::BlockArgListType());
2750     return;
2751   }
2752 
2753   assert(point == getBody() && "unexpected region index");
2754   regions.emplace_back(getOperation()->getResults());
2755 }
2756 
2757 void transform::SequenceOp::getRegionInvocationBounds(
2758     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2759   (void)operands;
2760   bounds.emplace_back(1, 1);
2761 }
2762 
2763 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2764                                   TypeRange resultTypes,
2765                                   FailurePropagationMode failurePropagationMode,
2766                                   Value root,
2767                                   SequenceBodyBuilderFn bodyBuilder) {
2768   build(builder, state, resultTypes, failurePropagationMode, root,
2769         /*extra_bindings=*/ValueRange());
2770   Type bbArgType = root.getType();
2771   buildSequenceBody(builder, state, bbArgType,
2772                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2773 }
2774 
2775 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2776                                   TypeRange resultTypes,
2777                                   FailurePropagationMode failurePropagationMode,
2778                                   Value root, ValueRange extraBindings,
2779                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2780   build(builder, state, resultTypes, failurePropagationMode, root,
2781         extraBindings);
2782   buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2783                     bodyBuilder);
2784 }
2785 
2786 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2787                                   TypeRange resultTypes,
2788                                   FailurePropagationMode failurePropagationMode,
2789                                   Type bbArgType,
2790                                   SequenceBodyBuilderFn bodyBuilder) {
2791   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2792         /*extra_bindings=*/ValueRange());
2793   buildSequenceBody(builder, state, bbArgType,
2794                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2795 }
2796 
2797 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2798                                   TypeRange resultTypes,
2799                                   FailurePropagationMode failurePropagationMode,
2800                                   Type bbArgType, TypeRange extraBindingTypes,
2801                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2802   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2803         /*extra_bindings=*/ValueRange());
2804   buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2805 }
2806 
2807 //===----------------------------------------------------------------------===//
2808 // PrintOp
2809 //===----------------------------------------------------------------------===//
2810 
2811 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2812                                StringRef name) {
2813   if (!name.empty())
2814     result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2815 }
2816 
2817 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2818                                Value target, StringRef name) {
2819   result.addOperands({target});
2820   build(builder, result, name);
2821 }
2822 
2823 DiagnosedSilenceableFailure
2824 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2825                           transform::TransformResults &results,
2826                           transform::TransformState &state) {
2827   llvm::outs() << "[[[ IR printer: ";
2828   if (getName().has_value())
2829     llvm::outs() << *getName() << " ";
2830 
2831   OpPrintingFlags printFlags;
2832   if (getAssumeVerified().value_or(false))
2833     printFlags.assumeVerified();
2834   if (getUseLocalScope().value_or(false))
2835     printFlags.useLocalScope();
2836   if (getSkipRegions().value_or(false))
2837     printFlags.skipRegions();
2838 
2839   if (!getTarget()) {
2840     llvm::outs() << "top-level ]]]\n";
2841     state.getTopLevel()->print(llvm::outs(), printFlags);
2842     llvm::outs() << "\n";
2843     llvm::outs().flush();
2844     return DiagnosedSilenceableFailure::success();
2845   }
2846 
2847   llvm::outs() << "]]]\n";
2848   for (Operation *target : state.getPayloadOps(getTarget())) {
2849     target->print(llvm::outs(), printFlags);
2850     llvm::outs() << "\n";
2851   }
2852 
2853   llvm::outs().flush();
2854   return DiagnosedSilenceableFailure::success();
2855 }
2856 
2857 void transform::PrintOp::getEffects(
2858     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2859   // We don't really care about mutability here, but `getTarget` now
2860   // unconditionally casts to a specific type before verification could run
2861   // here.
2862   if (!getTargetMutable().empty())
2863     onlyReadsHandle(getTargetMutable()[0], effects);
2864   onlyReadsPayload(effects);
2865 
2866   // There is no resource for stderr file descriptor, so just declare print
2867   // writes into the default resource.
2868   effects.emplace_back(MemoryEffects::Write::get());
2869 }
2870 
2871 //===----------------------------------------------------------------------===//
2872 // VerifyOp
2873 //===----------------------------------------------------------------------===//
2874 
2875 DiagnosedSilenceableFailure
2876 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2877                                 Operation *target,
2878                                 transform::ApplyToEachResultList &results,
2879                                 transform::TransformState &state) {
2880   if (failed(::mlir::verify(target))) {
2881     DiagnosedDefiniteFailure diag = emitDefiniteFailure()
2882                                     << "failed to verify payload op";
2883     diag.attachNote(target->getLoc()) << "payload op";
2884     return diag;
2885   }
2886   return DiagnosedSilenceableFailure::success();
2887 }
2888 
2889 void transform::VerifyOp::getEffects(
2890     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2891   transform::onlyReadsHandle(getTargetMutable(), effects);
2892 }
2893 
2894 //===----------------------------------------------------------------------===//
2895 // YieldOp
2896 //===----------------------------------------------------------------------===//
2897 
2898 void transform::YieldOp::getEffects(
2899     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2900   onlyReadsHandle(getOperandsMutable(), effects);
2901 }
2902