xref: /llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp (revision 5caab8bbc0f89f46aca07be2090c8d23c78605ba)
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
15 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dominance.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/Verifier.h"
23 #include "mlir/Interfaces/ControlFlowInterfaces.h"
24 #include "mlir/Interfaces/FunctionImplementation.h"
25 #include "mlir/Interfaces/FunctionInterfaces.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Pass/PassManager.h"
28 #include "mlir/Pass/PassRegistry.h"
29 #include "mlir/Transforms/CSE.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/ScopeExit.h"
35 #include "llvm/ADT/SmallPtrSet.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/Debug.h"
38 #include <optional>
39 
40 #define DEBUG_TYPE "transform-dialect"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
42 
43 #define DEBUG_TYPE_MATCHER "transform-matcher"
44 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
45 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
46 
47 using namespace mlir;
48 
49 static ParseResult parseSequenceOpOperands(
50     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
51     Type &rootType,
52     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
53     SmallVectorImpl<Type> &extraBindingTypes);
54 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
55                                     Value root, Type rootType,
56                                     ValueRange extraBindings,
57                                     TypeRange extraBindingTypes);
58 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
59                                      ArrayAttr matchers, ArrayAttr actions);
60 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
61                                             ArrayAttr &matchers,
62                                             ArrayAttr &actions);
63 
64 /// Helper function to check if the given transform op is contained in (or
65 /// equal to) the given payload target op. In that case, an error is returned.
66 /// Transforming transform IR that is currently executing is generally unsafe.
67 static DiagnosedSilenceableFailure
68 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
69                                      Operation *payload) {
70   Operation *transformAncestor = transform.getOperation();
71   while (transformAncestor) {
72     if (transformAncestor == payload) {
73       DiagnosedDefiniteFailure diag =
74           transform.emitDefiniteFailure()
75           << "cannot apply transform to itself (or one of its ancestors)";
76       diag.attachNote(payload->getLoc()) << "target payload op";
77       return diag;
78     }
79     transformAncestor = transformAncestor->getParentOp();
80   }
81   return DiagnosedSilenceableFailure::success();
82 }
83 
84 #define GET_OP_CLASSES
85 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
86 
87 //===----------------------------------------------------------------------===//
88 // AlternativesOp
89 //===----------------------------------------------------------------------===//
90 
91 OperandRange
92 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
93   if (!point.isParent() && getOperation()->getNumOperands() == 1)
94     return getOperation()->getOperands();
95   return OperandRange(getOperation()->operand_end(),
96                       getOperation()->operand_end());
97 }
98 
99 void transform::AlternativesOp::getSuccessorRegions(
100     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
101   for (Region &alternative : llvm::drop_begin(
102            getAlternatives(),
103            point.isParent() ? 0
104                             : point.getRegionOrNull()->getRegionNumber() + 1)) {
105     regions.emplace_back(&alternative, !getOperands().empty()
106                                            ? alternative.getArguments()
107                                            : Block::BlockArgListType());
108   }
109   if (!point.isParent())
110     regions.emplace_back(getOperation()->getResults());
111 }
112 
113 void transform::AlternativesOp::getRegionInvocationBounds(
114     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
115   (void)operands;
116   // The region corresponding to the first alternative is always executed, the
117   // remaining may or may not be executed.
118   bounds.reserve(getNumRegions());
119   bounds.emplace_back(1, 1);
120   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
121 }
122 
123 static void forwardEmptyOperands(Block *block, transform::TransformState &state,
124                                  transform::TransformResults &results) {
125   for (const auto &res : block->getParentOp()->getOpResults())
126     results.set(res, {});
127 }
128 
129 DiagnosedSilenceableFailure
130 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
131                                  transform::TransformResults &results,
132                                  transform::TransformState &state) {
133   SmallVector<Operation *> originals;
134   if (Value scopeHandle = getScope())
135     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
136   else
137     originals.push_back(state.getTopLevel());
138 
139   for (Operation *original : originals) {
140     if (original->isAncestor(getOperation())) {
141       auto diag = emitDefiniteFailure()
142                   << "scope must not contain the transforms being applied";
143       diag.attachNote(original->getLoc()) << "scope";
144       return diag;
145     }
146     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
147       auto diag = emitDefiniteFailure()
148                   << "only isolated-from-above ops can be alternative scopes";
149       diag.attachNote(original->getLoc()) << "scope";
150       return diag;
151     }
152   }
153 
154   for (Region &reg : getAlternatives()) {
155     // Clone the scope operations and make the transforms in this alternative
156     // region apply to them by virtue of mapping the block argument (the only
157     // visible handle) to the cloned scope operations. This effectively prevents
158     // the transformation from accessing any IR outside the scope.
159     auto scope = state.make_region_scope(reg);
160     auto clones = llvm::to_vector(
161         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
162     auto deleteClones = llvm::make_scope_exit([&] {
163       for (Operation *clone : clones)
164         clone->erase();
165     });
166     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
167       return DiagnosedSilenceableFailure::definiteFailure();
168 
169     bool failed = false;
170     for (Operation &transform : reg.front().without_terminator()) {
171       DiagnosedSilenceableFailure result =
172           state.applyTransform(cast<TransformOpInterface>(transform));
173       if (result.isSilenceableFailure()) {
174         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
175                           << "\n");
176         failed = true;
177         break;
178       }
179 
180       if (::mlir::failed(result.silence()))
181         return DiagnosedSilenceableFailure::definiteFailure();
182     }
183 
184     // If all operations in the given alternative succeeded, no need to consider
185     // the rest. Replace the original scoping operation with the clone on which
186     // the transformations were performed.
187     if (!failed) {
188       // We will be using the clones, so cancel their scheduled deletion.
189       deleteClones.release();
190       TrackingListener listener(state, *this);
191       IRRewriter rewriter(getContext(), &listener);
192       for (const auto &kvp : llvm::zip(originals, clones)) {
193         Operation *original = std::get<0>(kvp);
194         Operation *clone = std::get<1>(kvp);
195         original->getBlock()->getOperations().insert(original->getIterator(),
196                                                      clone);
197         rewriter.replaceOp(original, clone->getResults());
198       }
199       detail::forwardTerminatorOperands(&reg.front(), state, results);
200       return DiagnosedSilenceableFailure::success();
201     }
202   }
203   return emitSilenceableError() << "all alternatives failed";
204 }
205 
206 void transform::AlternativesOp::getEffects(
207     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
208   consumesHandle(getOperands(), effects);
209   producesHandle(getResults(), effects);
210   for (Region *region : getRegions()) {
211     if (!region->empty())
212       producesHandle(region->front().getArguments(), effects);
213   }
214   modifiesPayload(effects);
215 }
216 
217 LogicalResult transform::AlternativesOp::verify() {
218   for (Region &alternative : getAlternatives()) {
219     Block &block = alternative.front();
220     Operation *terminator = block.getTerminator();
221     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
222       InFlightDiagnostic diag = emitOpError()
223                                 << "expects terminator operands to have the "
224                                    "same type as results of the operation";
225       diag.attachNote(terminator->getLoc()) << "terminator";
226       return diag;
227     }
228   }
229 
230   return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // AnnotateOp
235 //===----------------------------------------------------------------------===//
236 
237 DiagnosedSilenceableFailure
238 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
239                              transform::TransformResults &results,
240                              transform::TransformState &state) {
241   SmallVector<Operation *> targets =
242       llvm::to_vector(state.getPayloadOps(getTarget()));
243 
244   Attribute attr = UnitAttr::get(getContext());
245   if (auto paramH = getParam()) {
246     ArrayRef<Attribute> params = state.getParams(paramH);
247     if (params.size() != 1) {
248       if (targets.size() != params.size()) {
249         return emitSilenceableError()
250                << "parameter and target have different payload lengths ("
251                << params.size() << " vs " << targets.size() << ")";
252       }
253       for (auto &&[target, attr] : llvm::zip_equal(targets, params))
254         target->setAttr(getName(), attr);
255       return DiagnosedSilenceableFailure::success();
256     }
257     attr = params[0];
258   }
259   for (auto target : targets)
260     target->setAttr(getName(), attr);
261   return DiagnosedSilenceableFailure::success();
262 }
263 
264 void transform::AnnotateOp::getEffects(
265     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
266   onlyReadsHandle(getTarget(), effects);
267   onlyReadsHandle(getParam(), effects);
268   modifiesPayload(effects);
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // ApplyCommonSubexpressionEliminationOp
273 //===----------------------------------------------------------------------===//
274 
275 DiagnosedSilenceableFailure
276 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
277     transform::TransformRewriter &rewriter, Operation *target,
278     ApplyToEachResultList &results, transform::TransformState &state) {
279   // Make sure that this transform is not applied to itself. Modifying the
280   // transform IR while it is being interpreted is generally dangerous.
281   DiagnosedSilenceableFailure payloadCheck =
282       ensurePayloadIsSeparateFromTransform(*this, target);
283   if (!payloadCheck.succeeded())
284     return payloadCheck;
285 
286   DominanceInfo domInfo;
287   mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
288   return DiagnosedSilenceableFailure::success();
289 }
290 
291 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
292     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
293   transform::onlyReadsHandle(getTarget(), effects);
294   transform::modifiesPayload(effects);
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // ApplyDeadCodeEliminationOp
299 //===----------------------------------------------------------------------===//
300 
301 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
302     transform::TransformRewriter &rewriter, Operation *target,
303     ApplyToEachResultList &results, transform::TransformState &state) {
304   // Make sure that this transform is not applied to itself. Modifying the
305   // transform IR while it is being interpreted is generally dangerous.
306   DiagnosedSilenceableFailure payloadCheck =
307       ensurePayloadIsSeparateFromTransform(*this, target);
308   if (!payloadCheck.succeeded())
309     return payloadCheck;
310 
311   // Maintain a worklist of potentially dead ops.
312   SetVector<Operation *> worklist;
313 
314   // Helper function that adds all defining ops of used values (operands and
315   // operands of nested ops).
316   auto addDefiningOpsToWorklist = [&](Operation *op) {
317     op->walk([&](Operation *op) {
318       for (Value v : op->getOperands())
319         if (Operation *defOp = v.getDefiningOp())
320           if (target->isProperAncestor(defOp))
321             worklist.insert(defOp);
322     });
323   };
324 
325   // Helper function that erases an op.
326   auto eraseOp = [&](Operation *op) {
327     // Remove op and nested ops from the worklist.
328     op->walk([&](Operation *op) {
329       auto it = llvm::find(worklist, op);
330       if (it != worklist.end())
331         worklist.erase(it);
332     });
333     rewriter.eraseOp(op);
334   };
335 
336   // Initial walk over the IR.
337   target->walk<WalkOrder::PostOrder>([&](Operation *op) {
338     if (op != target && isOpTriviallyDead(op)) {
339       addDefiningOpsToWorklist(op);
340       eraseOp(op);
341     }
342   });
343 
344   // Erase all ops that have become dead.
345   while (!worklist.empty()) {
346     Operation *op = worklist.pop_back_val();
347     if (!isOpTriviallyDead(op))
348       continue;
349     addDefiningOpsToWorklist(op);
350     eraseOp(op);
351   }
352 
353   return DiagnosedSilenceableFailure::success();
354 }
355 
356 void transform::ApplyDeadCodeEliminationOp::getEffects(
357     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
358   transform::onlyReadsHandle(getTarget(), effects);
359   transform::modifiesPayload(effects);
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // ApplyPatternsOp
364 //===----------------------------------------------------------------------===//
365 
366 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
367     transform::TransformRewriter &rewriter, Operation *target,
368     ApplyToEachResultList &results, transform::TransformState &state) {
369   // Make sure that this transform is not applied to itself. Modifying the
370   // transform IR while it is being interpreted is generally dangerous. Even
371   // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
372   // performs many additional simplifications such as dead code elimination.
373   DiagnosedSilenceableFailure payloadCheck =
374       ensurePayloadIsSeparateFromTransform(*this, target);
375   if (!payloadCheck.succeeded())
376     return payloadCheck;
377 
378   // Gather all specified patterns.
379   MLIRContext *ctx = target->getContext();
380   RewritePatternSet patterns(ctx);
381   if (!getRegion().empty()) {
382     for (Operation &op : getRegion().front()) {
383       cast<transform::PatternDescriptorOpInterface>(&op)
384           .populatePatternsWithState(patterns, state);
385     }
386   }
387 
388   // Configure the GreedyPatternRewriteDriver.
389   GreedyRewriteConfig config;
390   config.listener =
391       static_cast<RewriterBase::Listener *>(rewriter.getListener());
392   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
393 
394   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
395   // was requested, apply the greedy pattern rewrite only once. (The greedy
396   // pattern rewrite driver already iterates to a fixpoint internally.)
397   bool cseChanged = false;
398   // One or two iterations should be sufficient. Stop iterating after a certain
399   // threshold to make debugging easier.
400   static const int64_t kNumMaxIterations = 50;
401   int64_t iteration = 0;
402   do {
403     LogicalResult result = failure();
404     if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
405       // Op is isolated from above. Apply patterns and also perform region
406       // simplification.
407       result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
408     } else {
409       // Manually gather list of ops because the other
410       // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
411       // from above. This way, patterns can be applied to ops that are not
412       // isolated from above. Regions are not being simplified. Furthermore,
413       // only a single greedy rewrite iteration is performed.
414       SmallVector<Operation *> ops;
415       target->walk([&](Operation *nestedOp) {
416         if (target != nestedOp)
417           ops.push_back(nestedOp);
418       });
419       result = applyOpPatternsAndFold(ops, frozenPatterns, config);
420     }
421 
422     // A failure typically indicates that the pattern application did not
423     // converge.
424     if (failed(result)) {
425       return emitSilenceableFailure(target)
426              << "greedy pattern application failed";
427     }
428 
429     if (getApplyCse()) {
430       DominanceInfo domInfo;
431       mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
432                                           &cseChanged);
433     }
434   } while (cseChanged && ++iteration < kNumMaxIterations);
435 
436   if (iteration == kNumMaxIterations)
437     return emitDefiniteFailure() << "fixpoint iteration did not converge";
438 
439   return DiagnosedSilenceableFailure::success();
440 }
441 
442 LogicalResult transform::ApplyPatternsOp::verify() {
443   if (!getRegion().empty()) {
444     for (Operation &op : getRegion().front()) {
445       if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
446         InFlightDiagnostic diag = emitOpError()
447                                   << "expected children ops to implement "
448                                      "PatternDescriptorOpInterface";
449         diag.attachNote(op.getLoc()) << "op without interface";
450         return diag;
451       }
452     }
453   }
454   return success();
455 }
456 
457 void transform::ApplyPatternsOp::getEffects(
458     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
459   transform::onlyReadsHandle(getTarget(), effects);
460   transform::modifiesPayload(effects);
461 }
462 
463 void transform::ApplyPatternsOp::build(
464     OpBuilder &builder, OperationState &result, Value target,
465     function_ref<void(OpBuilder &, Location)> bodyBuilder) {
466   result.addOperands(target);
467 
468   OpBuilder::InsertionGuard g(builder);
469   Region *region = result.addRegion();
470   builder.createBlock(region);
471   if (bodyBuilder)
472     bodyBuilder(builder, result.location);
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // ApplyCanonicalizationPatternsOp
477 //===----------------------------------------------------------------------===//
478 
479 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
480     RewritePatternSet &patterns) {
481   MLIRContext *ctx = patterns.getContext();
482   for (Dialect *dialect : ctx->getLoadedDialects())
483     dialect->getCanonicalizationPatterns(patterns);
484   for (RegisteredOperationName op : ctx->getRegisteredOperations())
485     op.getCanonicalizationPatterns(patterns, ctx);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // ApplyConversionPatternsOp
490 //===----------------------------------------------------------------------===//
491 
492 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
493     transform::TransformRewriter &rewriter,
494     transform::TransformResults &results, transform::TransformState &state) {
495   MLIRContext *ctx = getContext();
496 
497   // Instantiate the default type converter if a type converter builder is
498   // specified.
499   std::unique_ptr<TypeConverter> defaultTypeConverter;
500   transform::TypeConverterBuilderOpInterface typeConverterBuilder =
501       getDefaultTypeConverter();
502   if (typeConverterBuilder)
503     defaultTypeConverter = typeConverterBuilder.getTypeConverter();
504 
505   // Configure conversion target.
506   ConversionTarget conversionTarget(*getContext());
507   if (getLegalOps())
508     for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
509       conversionTarget.addLegalOp(
510           OperationName(cast<StringAttr>(attr).getValue(), ctx));
511   if (getIllegalOps())
512     for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
513       conversionTarget.addIllegalOp(
514           OperationName(cast<StringAttr>(attr).getValue(), ctx));
515   if (getLegalDialects())
516     for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
517       conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
518   if (getIllegalDialects())
519     for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
520       conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
521 
522   // Gather all specified patterns.
523   RewritePatternSet patterns(ctx);
524   // Need to keep the converters alive until after pattern application because
525   // the patterns take a reference to an object that would otherwise get out of
526   // scope.
527   SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
528   if (!getPatterns().empty()) {
529     for (Operation &op : getPatterns().front()) {
530       auto descriptor =
531           cast<transform::ConversionPatternDescriptorOpInterface>(&op);
532 
533       // Check if this pattern set specifies a type converter.
534       std::unique_ptr<TypeConverter> typeConverter =
535           descriptor.getTypeConverter();
536       TypeConverter *converter = nullptr;
537       if (typeConverter) {
538         keepAliveConverters.emplace_back(std::move(typeConverter));
539         converter = keepAliveConverters.back().get();
540       } else {
541         // No type converter specified: Use the default type converter.
542         if (!defaultTypeConverter) {
543           auto diag = emitDefiniteFailure()
544                       << "pattern descriptor does not specify type "
545                          "converter and apply_conversion_patterns op has "
546                          "no default type converter";
547           diag.attachNote(op.getLoc()) << "pattern descriptor op";
548           return diag;
549         }
550         converter = defaultTypeConverter.get();
551       }
552 
553       // Add descriptor-specific updates to the conversion target, which may
554       // depend on the final type converter. In structural converters, the
555       // legality of types dictates the dynamic legality of an operation.
556       descriptor.populateConversionTargetRules(*converter, conversionTarget);
557 
558       descriptor.populatePatterns(*converter, patterns);
559     }
560   }
561 
562   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
563   for (Operation *target : state.getPayloadOps(getTarget())) {
564     // Make sure that this transform is not applied to itself. Modifying the
565     // transform IR while it is being interpreted is generally dangerous.
566     DiagnosedSilenceableFailure payloadCheck =
567         ensurePayloadIsSeparateFromTransform(*this, target);
568     if (!payloadCheck.succeeded())
569       return payloadCheck;
570 
571     LogicalResult status = failure();
572     if (getPartialConversion()) {
573       status = applyPartialConversion(target, conversionTarget, frozenPatterns);
574     } else {
575       status = applyFullConversion(target, conversionTarget, frozenPatterns);
576     }
577 
578     if (failed(status)) {
579       auto diag = emitSilenceableError() << "dialect conversion failed";
580       diag.attachNote(target->getLoc()) << "target op";
581       return diag;
582     }
583   }
584 
585   return DiagnosedSilenceableFailure::success();
586 }
587 
588 LogicalResult transform::ApplyConversionPatternsOp::verify() {
589   if (getNumRegions() != 1 && getNumRegions() != 2)
590     return emitOpError() << "expected 1 or 2 regions";
591   if (!getPatterns().empty()) {
592     for (Operation &op : getPatterns().front()) {
593       if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
594         InFlightDiagnostic diag =
595             emitOpError() << "expected pattern children ops to implement "
596                              "ConversionPatternDescriptorOpInterface";
597         diag.attachNote(op.getLoc()) << "op without interface";
598         return diag;
599       }
600     }
601   }
602   if (getNumRegions() == 2) {
603     Region &typeConverterRegion = getRegion(1);
604     if (!llvm::hasSingleElement(typeConverterRegion.front()))
605       return emitOpError()
606              << "expected exactly one op in default type converter region";
607     auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
608         &typeConverterRegion.front().front());
609     if (!typeConverterOp) {
610       InFlightDiagnostic diag = emitOpError()
611                                 << "expected default converter child op to "
612                                    "implement TypeConverterBuilderOpInterface";
613       diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
614       return diag;
615     }
616     // Check default type converter type.
617     if (!getPatterns().empty()) {
618       for (Operation &op : getPatterns().front()) {
619         auto descriptor =
620             cast<transform::ConversionPatternDescriptorOpInterface>(&op);
621         if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
622           return failure();
623       }
624     }
625   }
626   return success();
627 }
628 
629 void transform::ApplyConversionPatternsOp::getEffects(
630     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
631   transform::consumesHandle(getTarget(), effects);
632   transform::modifiesPayload(effects);
633 }
634 
635 void transform::ApplyConversionPatternsOp::build(
636     OpBuilder &builder, OperationState &result, Value target,
637     function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
638     function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
639   result.addOperands(target);
640 
641   {
642     OpBuilder::InsertionGuard g(builder);
643     Region *region1 = result.addRegion();
644     builder.createBlock(region1);
645     if (patternsBodyBuilder)
646       patternsBodyBuilder(builder, result.location);
647   }
648   {
649     OpBuilder::InsertionGuard g(builder);
650     Region *region2 = result.addRegion();
651     builder.createBlock(region2);
652     if (typeConverterBodyBuilder)
653       typeConverterBodyBuilder(builder, result.location);
654   }
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // ApplyToLLVMConversionPatternsOp
659 //===----------------------------------------------------------------------===//
660 
661 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
662     TypeConverter &typeConverter, RewritePatternSet &patterns) {
663   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
664   assert(dialect && "expected that dialect is loaded");
665   auto iface = cast<ConvertToLLVMPatternInterface>(dialect);
666   // ConversionTarget is currently ignored because the enclosing
667   // apply_conversion_patterns op sets up its own ConversionTarget.
668   ConversionTarget target(*getContext());
669   iface->populateConvertToLLVMConversionPatterns(
670       target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
671 }
672 
673 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
674     transform::TypeConverterBuilderOpInterface builder) {
675   if (builder.getTypeConverterType() != "LLVMTypeConverter")
676     return emitOpError("expected LLVMTypeConverter");
677   return success();
678 }
679 
680 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
681   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
682   if (!dialect)
683     return emitOpError("unknown dialect or dialect not loaded: ")
684            << getDialectName();
685   auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
686   if (!iface)
687     return emitOpError(
688                "dialect does not implement ConvertToLLVMPatternInterface or "
689                "extension was not loaded: ")
690            << getDialectName();
691   return success();
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // ApplyLoopInvariantCodeMotionOp
696 //===----------------------------------------------------------------------===//
697 
698 DiagnosedSilenceableFailure
699 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
700     transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
701     transform::ApplyToEachResultList &results,
702     transform::TransformState &state) {
703   // Currently, LICM does not remove operations, so we don't need tracking.
704   // If this ever changes, add a LICM entry point that takes a rewriter.
705   moveLoopInvariantCode(target);
706   return DiagnosedSilenceableFailure::success();
707 }
708 
709 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
710     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
711   transform::onlyReadsHandle(getTarget(), effects);
712   transform::modifiesPayload(effects);
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // ApplyRegisteredPassOp
717 //===----------------------------------------------------------------------===//
718 
719 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
720     transform::TransformRewriter &rewriter, Operation *target,
721     ApplyToEachResultList &results, transform::TransformState &state) {
722   // Make sure that this transform is not applied to itself. Modifying the
723   // transform IR while it is being interpreted is generally dangerous. Even
724   // more so when applying passes because they may perform a wide range of IR
725   // modifications.
726   DiagnosedSilenceableFailure payloadCheck =
727       ensurePayloadIsSeparateFromTransform(*this, target);
728   if (!payloadCheck.succeeded())
729     return payloadCheck;
730 
731   // Get pass or pass pipeline from registry.
732   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
733   if (!info)
734     info = PassInfo::lookup(getPassName());
735   if (!info)
736     return emitDefiniteFailure()
737            << "unknown pass or pass pipeline: " << getPassName();
738 
739   // Create pass manager and run the pass or pass pipeline.
740   PassManager pm(getContext());
741   if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
742         emitError(msg);
743         return failure();
744       }))) {
745     return emitDefiniteFailure()
746            << "failed to add pass or pass pipeline to pipeline: "
747            << getPassName();
748   }
749   if (failed(pm.run(target))) {
750     auto diag = emitSilenceableError() << "pass pipeline failed";
751     diag.attachNote(target->getLoc()) << "target op";
752     return diag;
753   }
754 
755   results.push_back(target);
756   return DiagnosedSilenceableFailure::success();
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // CastOp
761 //===----------------------------------------------------------------------===//
762 
763 DiagnosedSilenceableFailure
764 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
765                               Operation *target, ApplyToEachResultList &results,
766                               transform::TransformState &state) {
767   results.push_back(target);
768   return DiagnosedSilenceableFailure::success();
769 }
770 
771 void transform::CastOp::getEffects(
772     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
773   onlyReadsPayload(effects);
774   onlyReadsHandle(getInput(), effects);
775   producesHandle(getOutput(), effects);
776 }
777 
778 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
779   assert(inputs.size() == 1 && "expected one input");
780   assert(outputs.size() == 1 && "expected one output");
781   return llvm::all_of(
782       std::initializer_list<Type>{inputs.front(), outputs.front()},
783       [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // CollectMatchingOp
788 //===----------------------------------------------------------------------===//
789 
790 /// Applies matcher operations from the given `block` assigning `op` as the
791 /// payload of the block's first argument. Updates `state` accordingly. If any
792 /// of the matcher produces a silenceable failure, discards it (printing the
793 /// content to the debug output stream) and returns failure. If any of the
794 /// matchers produces a definite failure, reports it and returns failure. If all
795 /// matchers in the block succeed, populates `mappings` with the payload
796 /// entities associated with the block terminator operands.
797 static DiagnosedSilenceableFailure
798 matchBlock(Block &block, Operation *op, transform::TransformState &state,
799            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
800   assert(block.getParent() && "cannot match using a detached block");
801   auto matchScope = state.make_region_scope(*block.getParent());
802   if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
803     return DiagnosedSilenceableFailure::definiteFailure();
804 
805   for (Operation &match : block.without_terminator()) {
806     if (!isa<transform::MatchOpInterface>(match)) {
807       return emitDefiniteFailure(match.getLoc())
808              << "expected operations in the match part to "
809                 "implement MatchOpInterface";
810     }
811     DiagnosedSilenceableFailure diag =
812         state.applyTransform(cast<transform::TransformOpInterface>(match));
813     if (diag.succeeded())
814       continue;
815 
816     return diag;
817   }
818 
819   // Remember the values mapped to the terminator operands so we can
820   // forward them to the action.
821   ValueRange yieldedValues = block.getTerminator()->getOperands();
822   transform::detail::prepareValueMappings(mappings, yieldedValues, state);
823   return DiagnosedSilenceableFailure::success();
824 }
825 
826 /// Returns `true` if both types implement one of the interfaces provided as
827 /// template parameters.
828 template <typename... Tys>
829 static bool implementSameInterface(Type t1, Type t2) {
830   return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
831 }
832 
833 /// Returns `true` if both types implement one of the transform dialect
834 /// interfaces.
835 static bool implementSameTransformInterface(Type t1, Type t2) {
836   return implementSameInterface<transform::TransformHandleTypeInterface,
837                                 transform::TransformParamTypeInterface,
838                                 transform::TransformValueHandleTypeInterface>(
839       t1, t2);
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // CollectMatchingOp
844 //===----------------------------------------------------------------------===//
845 
846 DiagnosedSilenceableFailure
847 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
848                                     transform::TransformResults &results,
849                                     transform::TransformState &state) {
850   auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
851       getOperation(), getMatcher());
852   if (matcher.isExternal()) {
853     return emitDefiniteFailure()
854            << "unresolved external symbol " << getMatcher();
855   }
856 
857   SmallVector<SmallVector<MappedValue>, 2> rawResults;
858   rawResults.resize(getOperation()->getNumResults());
859   std::optional<DiagnosedSilenceableFailure> maybeFailure;
860   for (Operation *root : state.getPayloadOps(getRoot())) {
861     WalkResult walkResult = root->walk([&](Operation *op) {
862       DEBUG_MATCHER({
863         DBGS_MATCHER() << "matching ";
864         op->print(llvm::dbgs(),
865                   OpPrintingFlags().assumeVerified().skipRegions());
866         llvm::dbgs() << " @" << op << "\n";
867       });
868 
869       // Try matching.
870       SmallVector<SmallVector<MappedValue>> mappings;
871       DiagnosedSilenceableFailure diag =
872           matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
873       if (diag.isDefiniteFailure())
874         return WalkResult::interrupt();
875       if (diag.isSilenceableFailure()) {
876         DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
877                                      << " failed: " << diag.getMessage());
878         return WalkResult::advance();
879       }
880 
881       // If succeeded, collect results.
882       for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
883         if (mapping.size() != 1) {
884           maybeFailure.emplace(emitSilenceableError()
885                                << "result #" << i << ", associated with "
886                                << mapping.size()
887                                << " payload objects, expected 1");
888           return WalkResult::interrupt();
889         }
890         rawResults[i].push_back(mapping[0]);
891       }
892       return WalkResult::advance();
893     });
894     if (walkResult.wasInterrupted())
895       return std::move(*maybeFailure);
896     assert(!maybeFailure && "failure set but the walk was not interrupted");
897 
898     for (auto &&[opResult, rawResult] :
899          llvm::zip_equal(getOperation()->getResults(), rawResults)) {
900       results.setMappedValues(opResult, rawResult);
901     }
902   }
903   return DiagnosedSilenceableFailure::success();
904 }
905 
906 void transform::CollectMatchingOp::getEffects(
907     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
908   onlyReadsHandle(getRoot(), effects);
909   producesHandle(getResults(), effects);
910   onlyReadsPayload(effects);
911 }
912 
913 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
914     SymbolTableCollection &symbolTable) {
915   auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
916       symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
917   if (!matcherSymbol ||
918       !isa<TransformOpInterface>(matcherSymbol.getOperation()))
919     return emitError() << "unresolved matcher symbol " << getMatcher();
920 
921   ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
922   if (argumentTypes.size() != 1 ||
923       !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
924     return emitError()
925            << "expected the matcher to take one operation handle argument";
926   }
927   if (!matcherSymbol.getArgAttr(
928           0, transform::TransformDialect::kArgReadOnlyAttrName)) {
929     return emitError() << "expected the matcher argument to be marked readonly";
930   }
931 
932   ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
933   if (resultTypes.size() != getOperation()->getNumResults()) {
934     return emitError()
935            << "expected the matcher to yield as many values as op has results ("
936            << getOperation()->getNumResults() << "), got "
937            << resultTypes.size();
938   }
939 
940   for (auto &&[i, matcherType, resultType] :
941        llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
942     if (implementSameTransformInterface(matcherType, resultType))
943       continue;
944 
945     return emitError()
946            << "mismatching type interfaces for matcher result and op result #"
947            << i;
948   }
949 
950   return success();
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // ForeachMatchOp
955 //===----------------------------------------------------------------------===//
956 
957 DiagnosedSilenceableFailure
958 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
959                                  transform::TransformResults &results,
960                                  transform::TransformState &state) {
961   SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
962       matchActionPairs;
963   matchActionPairs.reserve(getMatchers().size());
964   SymbolTableCollection symbolTable;
965   for (auto &&[matcher, action] :
966        llvm::zip_equal(getMatchers(), getActions())) {
967     auto matcherSymbol =
968         symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
969             getOperation(), cast<SymbolRefAttr>(matcher));
970     auto actionSymbol =
971         symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
972             getOperation(), cast<SymbolRefAttr>(action));
973     assert(matcherSymbol && actionSymbol &&
974            "unresolved symbols not caught by the verifier");
975 
976     if (matcherSymbol.isExternal())
977       return emitDefiniteFailure() << "unresolved external symbol " << matcher;
978     if (actionSymbol.isExternal())
979       return emitDefiniteFailure() << "unresolved external symbol " << action;
980 
981     matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
982   }
983 
984   for (Operation *root : state.getPayloadOps(getRoot())) {
985     WalkResult walkResult = root->walk([&](Operation *op) {
986       // If getRestrictRoot is not present, skip over the root op itself so we
987       // don't invalidate it.
988       if (!getRestrictRoot() && op == root)
989         return WalkResult::advance();
990 
991       DEBUG_MATCHER({
992         DBGS_MATCHER() << "matching ";
993         op->print(llvm::dbgs(),
994                   OpPrintingFlags().assumeVerified().skipRegions());
995         llvm::dbgs() << " @" << op << "\n";
996       });
997 
998       // Try all the match/action pairs until the first successful match.
999       for (auto [matcher, action] : matchActionPairs) {
1000         SmallVector<SmallVector<MappedValue>> mappings;
1001         DiagnosedSilenceableFailure diag =
1002             matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1003         if (diag.isDefiniteFailure())
1004           return WalkResult::interrupt();
1005         if (diag.isSilenceableFailure()) {
1006           DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1007                                        << " failed: " << diag.getMessage());
1008           continue;
1009         }
1010 
1011         auto scope = state.make_region_scope(action.getFunctionBody());
1012         for (auto &&[arg, map] : llvm::zip_equal(
1013                  action.getFunctionBody().front().getArguments(), mappings)) {
1014           if (failed(state.mapBlockArgument(arg, map)))
1015             return WalkResult::interrupt();
1016         }
1017 
1018         for (Operation &transform :
1019              action.getFunctionBody().front().without_terminator()) {
1020           DiagnosedSilenceableFailure result =
1021               state.applyTransform(cast<TransformOpInterface>(transform));
1022           if (failed(result.checkAndReport()))
1023             return WalkResult::interrupt();
1024         }
1025         break;
1026       }
1027       return WalkResult::advance();
1028     });
1029     if (walkResult.wasInterrupted())
1030       return DiagnosedSilenceableFailure::definiteFailure();
1031   }
1032 
1033   // The root operation should not have been affected, so we can just reassign
1034   // the payload to the result. Note that we need to consume the root handle to
1035   // make sure any handles to operations inside, that could have been affected
1036   // by actions, are invalidated.
1037   results.set(llvm::cast<OpResult>(getUpdated()),
1038               state.getPayloadOps(getRoot()));
1039   return DiagnosedSilenceableFailure::success();
1040 }
1041 
1042 void transform::ForeachMatchOp::getEffects(
1043     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1044   // Bail if invalid.
1045   if (getOperation()->getNumOperands() < 1 ||
1046       getOperation()->getNumResults() < 1) {
1047     return modifiesPayload(effects);
1048   }
1049 
1050   consumesHandle(getRoot(), effects);
1051   producesHandle(getUpdated(), effects);
1052   modifiesPayload(effects);
1053 }
1054 
1055 /// Parses the comma-separated list of symbol reference pairs of the format
1056 /// `@matcher -> @action`.
1057 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1058                                             ArrayAttr &matchers,
1059                                             ArrayAttr &actions) {
1060   StringAttr matcher;
1061   StringAttr action;
1062   SmallVector<Attribute> matcherList;
1063   SmallVector<Attribute> actionList;
1064   do {
1065     if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1066         parser.parseSymbolName(action)) {
1067       return failure();
1068     }
1069     matcherList.push_back(SymbolRefAttr::get(matcher));
1070     actionList.push_back(SymbolRefAttr::get(action));
1071   } while (parser.parseOptionalComma().succeeded());
1072 
1073   matchers = parser.getBuilder().getArrayAttr(matcherList);
1074   actions = parser.getBuilder().getArrayAttr(actionList);
1075   return success();
1076 }
1077 
1078 /// Prints the comma-separated list of symbol reference pairs of the format
1079 /// `@matcher -> @action`.
1080 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
1081                                      ArrayAttr matchers, ArrayAttr actions) {
1082   printer.increaseIndent();
1083   printer.increaseIndent();
1084   for (auto &&[matcher, action, idx] : llvm::zip_equal(
1085            matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1086     printer.printNewline();
1087     printer << cast<SymbolRefAttr>(matcher) << " -> "
1088             << cast<SymbolRefAttr>(action);
1089     if (idx != matchers.size() - 1)
1090       printer << ", ";
1091   }
1092   printer.decreaseIndent();
1093   printer.decreaseIndent();
1094 }
1095 
1096 LogicalResult transform::ForeachMatchOp::verify() {
1097   if (getMatchers().size() != getActions().size())
1098     return emitOpError() << "expected the same number of matchers and actions";
1099   if (getMatchers().empty())
1100     return emitOpError() << "expected at least one match/action pair";
1101 
1102   llvm::SmallPtrSet<Attribute, 8> matcherNames;
1103   for (Attribute name : getMatchers()) {
1104     if (matcherNames.insert(name).second)
1105       continue;
1106     emitWarning() << "matcher " << name
1107                   << " is used more than once, only the first match will apply";
1108   }
1109 
1110   return success();
1111 }
1112 
1113 /// Checks that the attributes of the function-like operation have correct
1114 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1115 /// annotations being present even if they can be inferred from the body.
1116 static DiagnosedSilenceableFailure
1117 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1118                                      bool alsoVerifyInternal = false) {
1119   auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1120   llvm::SmallDenseSet<unsigned> consumedArguments;
1121   if (!op.isExternal()) {
1122     transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1123                                          consumedArguments);
1124   }
1125   for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1126     bool isConsumed =
1127         op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1128         nullptr;
1129     bool isReadOnly =
1130         op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1131         nullptr;
1132     if (isConsumed && isReadOnly) {
1133       return transformOp.emitSilenceableError()
1134              << "argument #" << i << " cannot be both readonly and consumed";
1135     }
1136     if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1137       return transformOp.emitSilenceableError()
1138              << "must provide consumed/readonly status for arguments of "
1139                 "external or called ops";
1140     }
1141     if (op.isExternal())
1142       continue;
1143 
1144     if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1145       return transformOp.emitSilenceableError()
1146              << "argument #" << i
1147              << " is consumed in the body but is not marked as such";
1148     }
1149     if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1150       // Cannot use op.emitWarning() here as it would attempt to verify the op
1151       // before printing, resulting in infinite recursion.
1152       emitWarning(op->getLoc())
1153           << "op argument #" << i
1154           << " is not consumed in the body but is marked as consumed";
1155     }
1156   }
1157   return DiagnosedSilenceableFailure::success();
1158 }
1159 
1160 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1161     SymbolTableCollection &symbolTable) {
1162   assert(getMatchers().size() == getActions().size());
1163   auto consumedAttr =
1164       StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1165   for (auto &&[matcher, action] :
1166        llvm::zip_equal(getMatchers(), getActions())) {
1167     auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1168         symbolTable.lookupNearestSymbolFrom(getOperation(),
1169                                             cast<SymbolRefAttr>(matcher)));
1170     auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1171         symbolTable.lookupNearestSymbolFrom(getOperation(),
1172                                             cast<SymbolRefAttr>(action)));
1173     if (!matcherSymbol ||
1174         !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1175       return emitError() << "unresolved matcher symbol " << matcher;
1176     if (!actionSymbol ||
1177         !isa<TransformOpInterface>(actionSymbol.getOperation()))
1178       return emitError() << "unresolved action symbol " << action;
1179 
1180     if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1181                                                     /*emitWarnings=*/false,
1182                                                     /*alsoVerifyInternal=*/true)
1183                    .checkAndReport())) {
1184       return failure();
1185     }
1186     if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1187                                                     /*emitWarnings=*/false,
1188                                                     /*alsoVerifyInternal=*/true)
1189                    .checkAndReport())) {
1190       return failure();
1191     }
1192 
1193     ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1194     ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1195     if (matcherResults.size() != actionArguments.size()) {
1196       return emitError() << "mismatching number of matcher results and "
1197                             "action arguments between "
1198                          << matcher << " (" << matcherResults.size() << ") and "
1199                          << action << " (" << actionArguments.size() << ")";
1200     }
1201     for (auto &&[i, matcherType, actionType] :
1202          llvm::enumerate(matcherResults, actionArguments)) {
1203       if (implementSameTransformInterface(matcherType, actionType))
1204         continue;
1205 
1206       return emitError() << "mismatching type interfaces for matcher result "
1207                             "and action argument #"
1208                          << i;
1209     }
1210 
1211     if (!actionSymbol.getResultTypes().empty()) {
1212       InFlightDiagnostic diag =
1213           emitError() << "action symbol is not expected to have results";
1214       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1215       return diag;
1216     }
1217 
1218     if (matcherSymbol.getArgumentTypes().size() != 1 ||
1219         !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1220                                          getRoot().getType())) {
1221       InFlightDiagnostic diag =
1222           emitOpError() << "expects matcher symbol to have one argument with "
1223                            "the same transform interface as the first operand";
1224       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1225       return diag;
1226     }
1227 
1228     if (matcherSymbol.getArgAttr(0, consumedAttr)) {
1229       InFlightDiagnostic diag =
1230           emitOpError()
1231           << "does not expect matcher symbol to consume its operand";
1232       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1233       return diag;
1234     }
1235   }
1236   return success();
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // ForeachOp
1241 //===----------------------------------------------------------------------===//
1242 
1243 DiagnosedSilenceableFailure
1244 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1245                             transform::TransformResults &results,
1246                             transform::TransformState &state) {
1247   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1248   // Store payload ops in a vector because ops may be removed from the mapping
1249   // by the TrackingRewriter while the iteration is in progress.
1250   SmallVector<Operation *> targets =
1251       llvm::to_vector(state.getPayloadOps(getTarget()));
1252   for (Operation *op : targets) {
1253     auto scope = state.make_region_scope(getBody());
1254     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
1255       return DiagnosedSilenceableFailure::definiteFailure();
1256 
1257     // Execute loop body.
1258     for (Operation &transform : getBody().front().without_terminator()) {
1259       DiagnosedSilenceableFailure result = state.applyTransform(
1260           cast<transform::TransformOpInterface>(transform));
1261       if (!result.succeeded())
1262         return result;
1263     }
1264 
1265     // Append yielded payload ops to result list (if any).
1266     for (unsigned i = 0; i < getNumResults(); ++i) {
1267       auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1268       resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1269     }
1270   }
1271 
1272   for (unsigned i = 0; i < getNumResults(); ++i)
1273     results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1274 
1275   return DiagnosedSilenceableFailure::success();
1276 }
1277 
1278 void transform::ForeachOp::getEffects(
1279     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1280   BlockArgument iterVar = getIterationVariable();
1281   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1282         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
1283       })) {
1284     consumesHandle(getTarget(), effects);
1285   } else {
1286     onlyReadsHandle(getTarget(), effects);
1287   }
1288 
1289   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1290         return doesModifyPayload(cast<TransformOpInterface>(&op));
1291       })) {
1292     modifiesPayload(effects);
1293   } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1294                return doesReadPayload(cast<TransformOpInterface>(&op));
1295              })) {
1296     onlyReadsPayload(effects);
1297   }
1298 
1299   for (Value result : getResults())
1300     producesHandle(result, effects);
1301 }
1302 
1303 void transform::ForeachOp::getSuccessorRegions(
1304     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1305   Region *bodyRegion = &getBody();
1306   if (point.isParent()) {
1307     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1308     return;
1309   }
1310 
1311   // Branch back to the region or the parent.
1312   assert(point == getBody() && "unexpected region index");
1313   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1314   regions.emplace_back();
1315 }
1316 
1317 OperandRange
1318 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1319   // The iteration variable op handle is mapped to a subset (one op to be
1320   // precise) of the payload ops of the ForeachOp operand.
1321   assert(point == getBody() && "unexpected region index");
1322   return getOperation()->getOperands();
1323 }
1324 
1325 transform::YieldOp transform::ForeachOp::getYieldOp() {
1326   return cast<transform::YieldOp>(getBody().front().getTerminator());
1327 }
1328 
1329 LogicalResult transform::ForeachOp::verify() {
1330   auto yieldOp = getYieldOp();
1331   if (getNumResults() != yieldOp.getNumOperands())
1332     return emitOpError() << "expects the same number of results as the "
1333                             "terminator has operands";
1334   for (Value v : yieldOp.getOperands())
1335     if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1336       return yieldOp->emitOpError("expects operands to have types implementing "
1337                                   "TransformHandleTypeInterface");
1338   return success();
1339 }
1340 
1341 //===----------------------------------------------------------------------===//
1342 // GetParentOp
1343 //===----------------------------------------------------------------------===//
1344 
1345 DiagnosedSilenceableFailure
1346 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1347                               transform::TransformResults &results,
1348                               transform::TransformState &state) {
1349   SmallVector<Operation *> parents;
1350   DenseSet<Operation *> resultSet;
1351   for (Operation *target : state.getPayloadOps(getTarget())) {
1352     Operation *parent = target;
1353     for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1354       parent = parent->getParentOp();
1355       while (parent) {
1356         bool checkIsolatedFromAbove =
1357             !getIsolatedFromAbove() ||
1358             parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1359         bool checkOpName = !getOpName().has_value() ||
1360                            parent->getName().getStringRef() == *getOpName();
1361         if (checkIsolatedFromAbove && checkOpName)
1362           break;
1363         parent = parent->getParentOp();
1364       }
1365       if (!parent) {
1366         if (getAllowEmptyResults()) {
1367           results.set(llvm::cast<OpResult>(getResult()), parents);
1368           return DiagnosedSilenceableFailure::success();
1369         }
1370         DiagnosedSilenceableFailure diag =
1371             emitSilenceableError()
1372             << "could not find a parent op that matches all requirements";
1373         diag.attachNote(target->getLoc()) << "target op";
1374         return diag;
1375       }
1376     }
1377     if (getDeduplicate()) {
1378       if (!resultSet.contains(parent)) {
1379         parents.push_back(parent);
1380         resultSet.insert(parent);
1381       }
1382     } else {
1383       parents.push_back(parent);
1384     }
1385   }
1386   results.set(llvm::cast<OpResult>(getResult()), parents);
1387   return DiagnosedSilenceableFailure::success();
1388 }
1389 
1390 //===----------------------------------------------------------------------===//
1391 // GetConsumersOfResult
1392 //===----------------------------------------------------------------------===//
1393 
1394 DiagnosedSilenceableFailure
1395 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1396                                        transform::TransformResults &results,
1397                                        transform::TransformState &state) {
1398   int64_t resultNumber = getResultNumber();
1399   auto payloadOps = state.getPayloadOps(getTarget());
1400   if (std::empty(payloadOps)) {
1401     results.set(cast<OpResult>(getResult()), {});
1402     return DiagnosedSilenceableFailure::success();
1403   }
1404   if (!llvm::hasSingleElement(payloadOps))
1405     return emitDefiniteFailure()
1406            << "handle must be mapped to exactly one payload op";
1407 
1408   Operation *target = *payloadOps.begin();
1409   if (target->getNumResults() <= resultNumber)
1410     return emitDefiniteFailure() << "result number overflow";
1411   results.set(llvm::cast<OpResult>(getResult()),
1412               llvm::to_vector(target->getResult(resultNumber).getUsers()));
1413   return DiagnosedSilenceableFailure::success();
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // GetDefiningOp
1418 //===----------------------------------------------------------------------===//
1419 
1420 DiagnosedSilenceableFailure
1421 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1422                                 transform::TransformResults &results,
1423                                 transform::TransformState &state) {
1424   SmallVector<Operation *> definingOps;
1425   for (Value v : state.getPayloadValues(getTarget())) {
1426     if (llvm::isa<BlockArgument>(v)) {
1427       DiagnosedSilenceableFailure diag =
1428           emitSilenceableError() << "cannot get defining op of block argument";
1429       diag.attachNote(v.getLoc()) << "target value";
1430       return diag;
1431     }
1432     definingOps.push_back(v.getDefiningOp());
1433   }
1434   results.set(llvm::cast<OpResult>(getResult()), definingOps);
1435   return DiagnosedSilenceableFailure::success();
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 // GetProducerOfOperand
1440 //===----------------------------------------------------------------------===//
1441 
1442 DiagnosedSilenceableFailure
1443 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1444                                        transform::TransformResults &results,
1445                                        transform::TransformState &state) {
1446   int64_t operandNumber = getOperandNumber();
1447   SmallVector<Operation *> producers;
1448   for (Operation *target : state.getPayloadOps(getTarget())) {
1449     Operation *producer =
1450         target->getNumOperands() <= operandNumber
1451             ? nullptr
1452             : target->getOperand(operandNumber).getDefiningOp();
1453     if (!producer) {
1454       DiagnosedSilenceableFailure diag =
1455           emitSilenceableError()
1456           << "could not find a producer for operand number: " << operandNumber
1457           << " of " << *target;
1458       diag.attachNote(target->getLoc()) << "target op";
1459       return diag;
1460     }
1461     producers.push_back(producer);
1462   }
1463   results.set(llvm::cast<OpResult>(getResult()), producers);
1464   return DiagnosedSilenceableFailure::success();
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // GetOperandOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 DiagnosedSilenceableFailure
1472 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1473                                transform::TransformResults &results,
1474                                transform::TransformState &state) {
1475   SmallVector<Value> operands;
1476   for (Operation *target : state.getPayloadOps(getTarget())) {
1477     SmallVector<int64_t> operandPositions;
1478     DiagnosedSilenceableFailure diag = expandTargetSpecification(
1479         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1480         target->getNumOperands(), operandPositions);
1481     if (diag.isSilenceableFailure()) {
1482       diag.attachNote(target->getLoc())
1483           << "while considering positions of this payload operation";
1484       return diag;
1485     }
1486     llvm::append_range(operands,
1487                        llvm::map_range(operandPositions, [&](int64_t pos) {
1488                          return target->getOperand(pos);
1489                        }));
1490   }
1491   results.setValues(cast<OpResult>(getResult()), operands);
1492   return DiagnosedSilenceableFailure::success();
1493 }
1494 
1495 LogicalResult transform::GetOperandOp::verify() {
1496   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1497                                     getIsInverted(), getIsAll());
1498 }
1499 
1500 //===----------------------------------------------------------------------===//
1501 // GetResultOp
1502 //===----------------------------------------------------------------------===//
1503 
1504 DiagnosedSilenceableFailure
1505 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1506                               transform::TransformResults &results,
1507                               transform::TransformState &state) {
1508   SmallVector<Value> opResults;
1509   for (Operation *target : state.getPayloadOps(getTarget())) {
1510     SmallVector<int64_t> resultPositions;
1511     DiagnosedSilenceableFailure diag = expandTargetSpecification(
1512         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1513         target->getNumResults(), resultPositions);
1514     if (diag.isSilenceableFailure()) {
1515       diag.attachNote(target->getLoc())
1516           << "while considering positions of this payload operation";
1517       return diag;
1518     }
1519     llvm::append_range(opResults,
1520                        llvm::map_range(resultPositions, [&](int64_t pos) {
1521                          return target->getResult(pos);
1522                        }));
1523   }
1524   results.setValues(cast<OpResult>(getResult()), opResults);
1525   return DiagnosedSilenceableFailure::success();
1526 }
1527 
1528 LogicalResult transform::GetResultOp::verify() {
1529   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1530                                     getIsInverted(), getIsAll());
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // GetTypeOp
1535 //===----------------------------------------------------------------------===//
1536 
1537 void transform::GetTypeOp::getEffects(
1538     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1539   onlyReadsHandle(getValue(), effects);
1540   producesHandle(getResult(), effects);
1541   onlyReadsPayload(effects);
1542 }
1543 
1544 DiagnosedSilenceableFailure
1545 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1546                             transform::TransformResults &results,
1547                             transform::TransformState &state) {
1548   SmallVector<Attribute> params;
1549   for (Value value : state.getPayloadValues(getValue())) {
1550     Type type = value.getType();
1551     if (getElemental()) {
1552       if (auto shaped = dyn_cast<ShapedType>(type)) {
1553         type = shaped.getElementType();
1554       }
1555     }
1556     params.push_back(TypeAttr::get(type));
1557   }
1558   results.setParams(getResult().cast<OpResult>(), params);
1559   return DiagnosedSilenceableFailure::success();
1560 }
1561 
1562 //===----------------------------------------------------------------------===//
1563 // IncludeOp
1564 //===----------------------------------------------------------------------===//
1565 
1566 /// Applies the transform ops contained in `block`. Maps `results` to the same
1567 /// values as the operands of the block terminator.
1568 static DiagnosedSilenceableFailure
1569 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1570                    transform::TransformState &state,
1571                    transform::TransformResults &results) {
1572   // Apply the sequenced ops one by one.
1573   for (Operation &transform : block.without_terminator()) {
1574     DiagnosedSilenceableFailure result =
1575         state.applyTransform(cast<transform::TransformOpInterface>(transform));
1576     if (result.isDefiniteFailure())
1577       return result;
1578 
1579     if (result.isSilenceableFailure()) {
1580       if (mode == transform::FailurePropagationMode::Propagate) {
1581         // Propagate empty results in case of early exit.
1582         forwardEmptyOperands(&block, state, results);
1583         return result;
1584       }
1585       (void)result.silence();
1586     }
1587   }
1588 
1589   // Forward the operation mapping for values yielded from the sequence to the
1590   // values produced by the sequence op.
1591   transform::detail::forwardTerminatorOperands(&block, state, results);
1592   return DiagnosedSilenceableFailure::success();
1593 }
1594 
1595 DiagnosedSilenceableFailure
1596 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1597                             transform::TransformResults &results,
1598                             transform::TransformState &state) {
1599   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1600       getOperation(), getTarget());
1601   assert(callee && "unverified reference to unknown symbol");
1602 
1603   if (callee.isExternal())
1604     return emitDefiniteFailure() << "unresolved external named sequence";
1605 
1606   // Map operands to block arguments.
1607   SmallVector<SmallVector<MappedValue>> mappings;
1608   detail::prepareValueMappings(mappings, getOperands(), state);
1609   auto scope = state.make_region_scope(callee.getBody());
1610   for (auto &&[arg, map] :
1611        llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1612     if (failed(state.mapBlockArgument(arg, map)))
1613       return DiagnosedSilenceableFailure::definiteFailure();
1614   }
1615 
1616   DiagnosedSilenceableFailure result = applySequenceBlock(
1617       callee.getBody().front(), getFailurePropagationMode(), state, results);
1618   mappings.clear();
1619   detail::prepareValueMappings(
1620       mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1621   for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1622     results.setMappedValues(result, mapping);
1623   return result;
1624 }
1625 
/// Forward declaration: the definition appears later in this file, but
/// `IncludeOp::getEffects` below already needs to call it.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1628 
1629 void transform::IncludeOp::getEffects(
1630     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1631   // Always mark as modifying the payload.
1632   // TODO: a mechanism to annotate effects on payload. Even when all handles are
1633   // only read, the payload may still be modified, so we currently stay on the
1634   // conservative side and always indicate modification. This may prevent some
1635   // code reordering.
1636   modifiesPayload(effects);
1637 
1638   // Results are always produced.
1639   producesHandle(getResults(), effects);
1640 
1641   // Adds default effects to operands and results. This will be added if
1642   // preconditions fail so the trait verifier doesn't complain about missing
1643   // effects and the real precondition failure is reported later on.
1644   auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
1645 
1646   // Bail if the callee is unknown. This may run as part of the verification
1647   // process before we verified the validity of the callee or of this op.
1648   auto target =
1649       getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1650   if (!target)
1651     return defaultEffects();
1652   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1653       getOperation(), getTarget());
1654   if (!callee)
1655     return defaultEffects();
1656   DiagnosedSilenceableFailure earlyVerifierResult =
1657       verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1658   if (!earlyVerifierResult.succeeded()) {
1659     (void)earlyVerifierResult.silence();
1660     return defaultEffects();
1661   }
1662 
1663   for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1664     if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1665       consumesHandle(getOperand(i), effects);
1666     else
1667       onlyReadsHandle(getOperand(i), effects);
1668   }
1669 }
1670 
1671 LogicalResult
1672 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1673   // Access through indirection and do additional checking because this may be
1674   // running before the main op verifier.
1675   auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1676   if (!targetAttr)
1677     return emitOpError() << "expects a 'target' symbol reference attribute";
1678 
1679   auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1680       *this, targetAttr);
1681   if (!target)
1682     return emitOpError() << "does not reference a named transform sequence";
1683 
1684   FunctionType fnType = target.getFunctionType();
1685   if (fnType.getNumInputs() != getNumOperands())
1686     return emitError("incorrect number of operands for callee");
1687 
1688   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1689     if (getOperand(i).getType() != fnType.getInput(i)) {
1690       return emitOpError("operand type mismatch: expected operand type ")
1691              << fnType.getInput(i) << ", but provided "
1692              << getOperand(i).getType() << " for operand number " << i;
1693     }
1694   }
1695 
1696   if (fnType.getNumResults() != getNumResults())
1697     return emitError("incorrect number of results for callee");
1698 
1699   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1700     Type resultType = getResult(i).getType();
1701     Type funcType = fnType.getResult(i);
1702     if (!implementSameTransformInterface(resultType, funcType)) {
1703       return emitOpError() << "type of result #" << i
1704                            << " must implement the same transform dialect "
1705                               "interface as the corresponding callee result";
1706     }
1707   }
1708 
1709   return verifyFunctionLikeConsumeAnnotations(
1710              cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1711              /*alsoVerifyInternal=*/true)
1712       .checkAndReport();
1713 }
1714 
1715 //===----------------------------------------------------------------------===//
1716 // MatchOperationEmptyOp
1717 //===----------------------------------------------------------------------===//
1718 
1719 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1720     ::std::optional<::mlir::Operation *> maybeCurrent,
1721     transform::TransformResults &results, transform::TransformState &state) {
1722   if (!maybeCurrent.has_value()) {
1723     DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1724     return DiagnosedSilenceableFailure::success();
1725   }
1726   DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1727   return emitSilenceableError() << "operation is not empty";
1728 }
1729 
1730 //===----------------------------------------------------------------------===//
1731 // MatchOperationNameOp
1732 //===----------------------------------------------------------------------===//
1733 
1734 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1735     Operation *current, transform::TransformResults &results,
1736     transform::TransformState &state) {
1737   StringRef currentOpName = current->getName().getStringRef();
1738   for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1739     if (acceptedAttr.getValue() == currentOpName)
1740       return DiagnosedSilenceableFailure::success();
1741   }
1742   return emitSilenceableError() << "wrong operation name";
1743 }
1744 
1745 //===----------------------------------------------------------------------===//
1746 // MatchParamCmpIOp
1747 //===----------------------------------------------------------------------===//
1748 
1749 DiagnosedSilenceableFailure
1750 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1751                                    transform::TransformResults &results,
1752                                    transform::TransformState &state) {
1753   auto signedAPIntAsString = [&](APInt value) {
1754     std::string str;
1755     llvm::raw_string_ostream os(str);
1756     value.print(os, /*isSigned=*/true);
1757     return os.str();
1758   };
1759 
1760   ArrayRef<Attribute> params = state.getParams(getParam());
1761   ArrayRef<Attribute> references = state.getParams(getReference());
1762 
1763   if (params.size() != references.size()) {
1764     return emitSilenceableError()
1765            << "parameters have different payload lengths (" << params.size()
1766            << " vs " << references.size() << ")";
1767   }
1768 
1769   for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1770     auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1771     auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1772     if (!intAttr || !refAttr) {
1773       return emitDefiniteFailure()
1774              << "non-integer parameter value not expected";
1775     }
1776     if (intAttr.getType() != refAttr.getType()) {
1777       return emitDefiniteFailure()
1778              << "mismatching integer attribute types in parameter #" << i;
1779     }
1780     APInt value = intAttr.getValue();
1781     APInt refValue = refAttr.getValue();
1782 
1783     // TODO: this copy will not be necessary in C++20.
1784     int64_t position = i;
1785     auto reportError = [&](StringRef direction) {
1786       DiagnosedSilenceableFailure diag =
1787           emitSilenceableError() << "expected parameter to be " << direction
1788                                  << " " << signedAPIntAsString(refValue)
1789                                  << ", got " << signedAPIntAsString(value);
1790       diag.attachNote(getParam().getLoc())
1791           << "value # " << position
1792           << " associated with the parameter defined here";
1793       return diag;
1794     };
1795 
1796     switch (getPredicate()) {
1797     case MatchCmpIPredicate::eq:
1798       if (value.eq(refValue))
1799         break;
1800       return reportError("equal to");
1801     case MatchCmpIPredicate::ne:
1802       if (value.ne(refValue))
1803         break;
1804       return reportError("not equal to");
1805     case MatchCmpIPredicate::lt:
1806       if (value.slt(refValue))
1807         break;
1808       return reportError("less than");
1809     case MatchCmpIPredicate::le:
1810       if (value.sle(refValue))
1811         break;
1812       return reportError("less than or equal to");
1813     case MatchCmpIPredicate::gt:
1814       if (value.sgt(refValue))
1815         break;
1816       return reportError("greater than");
1817     case MatchCmpIPredicate::ge:
1818       if (value.sge(refValue))
1819         break;
1820       return reportError("greater than or equal to");
1821     }
1822   }
1823   return DiagnosedSilenceableFailure::success();
1824 }
1825 
1826 void transform::MatchParamCmpIOp::getEffects(
1827     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1828   onlyReadsHandle(getParam(), effects);
1829   onlyReadsHandle(getReference(), effects);
1830 }
1831 
1832 //===----------------------------------------------------------------------===//
1833 // ParamConstantOp
1834 //===----------------------------------------------------------------------===//
1835 
1836 DiagnosedSilenceableFailure
1837 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
1838                                   transform::TransformResults &results,
1839                                   transform::TransformState &state) {
1840   results.setParams(cast<OpResult>(getParam()), {getValue()});
1841   return DiagnosedSilenceableFailure::success();
1842 }
1843 
1844 //===----------------------------------------------------------------------===//
1845 // MergeHandlesOp
1846 //===----------------------------------------------------------------------===//
1847 
1848 DiagnosedSilenceableFailure
1849 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
1850                                  transform::TransformResults &results,
1851                                  transform::TransformState &state) {
1852   ValueRange handles = getHandles();
1853   if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
1854     SmallVector<Operation *> operations;
1855     for (Value operand : handles)
1856       llvm::append_range(operations, state.getPayloadOps(operand));
1857     if (!getDeduplicate()) {
1858       results.set(llvm::cast<OpResult>(getResult()), operations);
1859       return DiagnosedSilenceableFailure::success();
1860     }
1861 
1862     SetVector<Operation *> uniqued(operations.begin(), operations.end());
1863     results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1864     return DiagnosedSilenceableFailure::success();
1865   }
1866 
1867   if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1868     SmallVector<Attribute> attrs;
1869     for (Value attribute : handles)
1870       llvm::append_range(attrs, state.getParams(attribute));
1871     if (!getDeduplicate()) {
1872       results.setParams(cast<OpResult>(getResult()), attrs);
1873       return DiagnosedSilenceableFailure::success();
1874     }
1875 
1876     SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
1877     results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1878     return DiagnosedSilenceableFailure::success();
1879   }
1880 
1881   assert(
1882       llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1883       "expected value handle type");
1884   SmallVector<Value> payloadValues;
1885   for (Value value : handles)
1886     llvm::append_range(payloadValues, state.getPayloadValues(value));
1887   if (!getDeduplicate()) {
1888     results.setValues(cast<OpResult>(getResult()), payloadValues);
1889     return DiagnosedSilenceableFailure::success();
1890   }
1891 
1892   SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
1893   results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1894   return DiagnosedSilenceableFailure::success();
1895 }
1896 
1897 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1898   // Handles may be the same if deduplicating is enabled.
1899   return getDeduplicate();
1900 }
1901 
1902 void transform::MergeHandlesOp::getEffects(
1903     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1904   onlyReadsHandle(getHandles(), effects);
1905   producesHandle(getResult(), effects);
1906 
1907   // There are no effects on the Payload IR as this is only a handle
1908   // manipulation.
1909 }
1910 
1911 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1912   if (getDeduplicate() || getHandles().size() != 1)
1913     return {};
1914 
1915   // If deduplication is not required and there is only one operand, it can be
1916   // used directly instead of merging.
1917   return getHandles().front();
1918 }
1919 
1920 //===----------------------------------------------------------------------===//
1921 // NamedSequenceOp
1922 //===----------------------------------------------------------------------===//
1923 
1924 DiagnosedSilenceableFailure
1925 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
1926                                   transform::TransformResults &results,
1927                                   transform::TransformState &state) {
1928   if (isExternal())
1929     return emitDefiniteFailure() << "unresolved external named sequence";
1930 
1931   // Map the entry block argument to the list of operations.
1932   // Note: this is the same implementation as PossibleTopLevelTransformOp but
1933   // without attaching the interface / trait since that is tailored to a
1934   // dangling top-level op that does not get "called".
1935   auto scope = state.make_region_scope(getBody());
1936   if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
1937           state, this->getOperation(), getBody())))
1938     return DiagnosedSilenceableFailure::definiteFailure();
1939 
1940   return applySequenceBlock(getBody().front(),
1941                             FailurePropagationMode::Propagate, state, results);
1942 }
1943 
1944 void transform::NamedSequenceOp::getEffects(
1945     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
1946 
1947 ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
1948                                               OperationState &result) {
1949   return function_interface_impl::parseFunctionOp(
1950       parser, result, /*allowVariadic=*/false,
1951       getFunctionTypeAttrName(result.name),
1952       [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
1953          function_interface_impl::VariadicFlag,
1954          std::string &) { return builder.getFunctionType(inputs, results); },
1955       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1956 }
1957 
1958 void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
1959   function_interface_impl::printFunctionOp(
1960       printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
1961       getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
1962       getResAttrsAttrName());
1963 }
1964 
1965 /// Verifies that a symbol function-like transform dialect operation has the
1966 /// signature and the terminator that have conforming types, i.e., types
1967 /// implementing the same transform dialect type interface. If `allowExternal`
1968 /// is set, allow external symbols (declarations) and don't check the terminator
1969 /// as it may not exist.
1970 static DiagnosedSilenceableFailure
1971 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
1972   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
1973     DiagnosedSilenceableFailure diag =
1974         emitSilenceableFailure(op)
1975         << "cannot be defined inside another transform op";
1976     diag.attachNote(parent.getLoc()) << "ancestor transform op";
1977     return diag;
1978   }
1979 
1980   if (op.isExternal() || op.getFunctionBody().empty()) {
1981     if (allowExternal)
1982       return DiagnosedSilenceableFailure::success();
1983 
1984     return emitSilenceableFailure(op) << "cannot be external";
1985   }
1986 
1987   if (op.getFunctionBody().front().empty())
1988     return emitSilenceableFailure(op) << "expected a non-empty body block";
1989 
1990   Operation *terminator = &op.getFunctionBody().front().back();
1991   if (!isa<transform::YieldOp>(terminator)) {
1992     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
1993                                        << "expected '"
1994                                        << transform::YieldOp::getOperationName()
1995                                        << "' as terminator";
1996     diag.attachNote(terminator->getLoc()) << "terminator";
1997     return diag;
1998   }
1999 
2000   if (terminator->getNumOperands() != op.getResultTypes().size()) {
2001     return emitSilenceableFailure(terminator)
2002            << "expected terminator to have as many operands as the parent op "
2003               "has results";
2004   }
2005   for (auto [i, operandType, resultType] : llvm::zip_equal(
2006            llvm::seq<unsigned>(0, terminator->getNumOperands()),
2007            terminator->getOperands().getType(), op.getResultTypes())) {
2008     if (operandType == resultType)
2009       continue;
2010     return emitSilenceableFailure(terminator)
2011            << "the type of the terminator operand #" << i
2012            << " must match the type of the corresponding parent op result ("
2013            << operandType << " vs " << resultType << ")";
2014   }
2015 
2016   return DiagnosedSilenceableFailure::success();
2017 }
2018 
2019 /// Verification of a NamedSequenceOp. This does not report the error
2020 /// immediately, so it can be used to check for op's well-formedness before the
2021 /// verifier runs, e.g., during trait verification.
2022 static DiagnosedSilenceableFailure
2023 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2024   if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2025     if (!parent->getAttr(
2026             transform::TransformDialect::kWithNamedSequenceAttrName)) {
2027       DiagnosedSilenceableFailure diag =
2028           emitSilenceableFailure(op)
2029           << "expects the parent symbol table to have the '"
2030           << transform::TransformDialect::kWithNamedSequenceAttrName
2031           << "' attribute";
2032       diag.attachNote(parent->getLoc()) << "symbol table operation";
2033       return diag;
2034     }
2035   }
2036 
2037   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2038     DiagnosedSilenceableFailure diag =
2039         emitSilenceableFailure(op)
2040         << "cannot be defined inside another transform op";
2041     diag.attachNote(parent.getLoc()) << "ancestor transform op";
2042     return diag;
2043   }
2044 
2045   if (op.isExternal() || op.getBody().empty())
2046     return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2047                                                 emitWarnings);
2048 
2049   if (op.getBody().front().empty())
2050     return emitSilenceableFailure(op) << "expected a non-empty body block";
2051 
2052   Operation *terminator = &op.getBody().front().back();
2053   if (!isa<transform::YieldOp>(terminator)) {
2054     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2055                                        << "expected '"
2056                                        << transform::YieldOp::getOperationName()
2057                                        << "' as terminator";
2058     diag.attachNote(terminator->getLoc()) << "terminator";
2059     return diag;
2060   }
2061 
2062   if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2063     return emitSilenceableFailure(terminator)
2064            << "expected terminator to have as many operands as the parent op "
2065               "has results";
2066   }
2067   for (auto [i, operandType, resultType] :
2068        llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2069                        terminator->getOperands().getType(),
2070                        op.getFunctionType().getResults())) {
2071     if (operandType == resultType)
2072       continue;
2073     return emitSilenceableFailure(terminator)
2074            << "the type of the terminator operand #" << i
2075            << " must match the type of the corresponding parent op result ("
2076            << operandType << " vs " << resultType << ")";
2077   }
2078 
2079   auto funcOp = cast<FunctionOpInterface>(*op);
2080   DiagnosedSilenceableFailure diag =
2081       verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2082   if (!diag.succeeded())
2083     return diag;
2084 
2085   return verifyYieldingSingleBlockOp(funcOp,
2086                                      /*allowExternal=*/true);
2087 }
2088 
2089 LogicalResult transform::NamedSequenceOp::verify() {
2090   // Actual verification happens in a separate function for reusability.
2091   return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2092 }
2093 
2094 template <typename FnTy>
2095 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2096                               Type bbArgType, TypeRange extraBindingTypes,
2097                               FnTy bodyBuilder) {
2098   SmallVector<Type> types;
2099   types.reserve(1 + extraBindingTypes.size());
2100   types.push_back(bbArgType);
2101   llvm::append_range(types, extraBindingTypes);
2102 
2103   OpBuilder::InsertionGuard guard(builder);
2104   Region *region = state.regions.back().get();
2105   Block *bodyBlock =
2106       builder.createBlock(region, region->begin(), types,
2107                           SmallVector<Location>(types.size(), state.location));
2108 
2109   // Populate body.
2110   builder.setInsertionPointToStart(bodyBlock);
2111   if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2112     bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2113   } else {
2114     bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2115                 bodyBlock->getArguments().drop_front());
2116   }
2117 }
2118 
2119 void transform::NamedSequenceOp::build(OpBuilder &builder,
2120                                        OperationState &state, StringRef symName,
2121                                        Type rootType, TypeRange resultTypes,
2122                                        SequenceBodyBuilderFn bodyBuilder,
2123                                        ArrayRef<NamedAttribute> attrs,
2124                                        ArrayRef<DictionaryAttr> argAttrs) {
2125   state.addAttribute(SymbolTable::getSymbolAttrName(),
2126                      builder.getStringAttr(symName));
2127   state.addAttribute(getFunctionTypeAttrName(state.name),
2128                      TypeAttr::get(FunctionType::get(builder.getContext(),
2129                                                      rootType, resultTypes)));
2130   state.attributes.append(attrs.begin(), attrs.end());
2131   state.addRegion();
2132 
2133   buildSequenceBody(builder, state, rootType,
2134                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2135 }
2136 
2137 //===----------------------------------------------------------------------===//
2138 // NumAssociationsOp
2139 //===----------------------------------------------------------------------===//
2140 
2141 DiagnosedSilenceableFailure
2142 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2143                                     transform::TransformResults &results,
2144                                     transform::TransformState &state) {
2145   size_t numAssociations =
2146       llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2147           .Case([&](TransformHandleTypeInterface opHandle) {
2148             return llvm::range_size(state.getPayloadOps(getHandle()));
2149           })
2150           .Case([&](TransformValueHandleTypeInterface valueHandle) {
2151             return llvm::range_size(state.getPayloadValues(getHandle()));
2152           })
2153           .Case([&](TransformParamTypeInterface param) {
2154             return llvm::range_size(state.getParams(getHandle()));
2155           })
2156           .Default([](Type) {
2157             llvm_unreachable("unknown kind of transform dialect type");
2158             return 0;
2159           });
2160   results.setParams(getNum().cast<OpResult>(),
2161                     rewriter.getI64IntegerAttr(numAssociations));
2162   return DiagnosedSilenceableFailure::success();
2163 }
2164 
2165 LogicalResult transform::NumAssociationsOp::verify() {
2166   // Verify that the result type accepts an i64 attribute as payload.
2167   auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
2168   return resultType
2169       .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2170       .checkAndReport();
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // SelectOp
2175 //===----------------------------------------------------------------------===//
2176 
2177 DiagnosedSilenceableFailure
2178 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2179                            transform::TransformResults &results,
2180                            transform::TransformState &state) {
2181   SmallVector<Operation *> result;
2182   auto payloadOps = state.getPayloadOps(getTarget());
2183   for (Operation *op : payloadOps) {
2184     if (op->getName().getStringRef() == getOpName())
2185       result.push_back(op);
2186   }
2187   results.set(cast<OpResult>(getResult()), result);
2188   return DiagnosedSilenceableFailure::success();
2189 }
2190 
2191 //===----------------------------------------------------------------------===//
2192 // SplitHandleOp
2193 //===----------------------------------------------------------------------===//
2194 
2195 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2196                                      Value target, int64_t numResultHandles) {
2197   result.addOperands(target);
2198   result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2199 }
2200 
2201 DiagnosedSilenceableFailure
2202 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2203                                 transform::TransformResults &results,
2204                                 transform::TransformState &state) {
2205   int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2206   auto produceNumOpsError = [&]() {
2207     return emitSilenceableError()
2208            << getHandle() << " expected to contain " << this->getNumResults()
2209            << " payload ops but it contains " << numPayloadOps
2210            << " payload ops";
2211   };
2212 
2213   // Fail if there are more payload ops than results and no overflow result was
2214   // specified.
2215   if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2216     return produceNumOpsError();
2217 
2218   // Fail if there are more results than payload ops. Unless:
2219   // - "fail_on_payload_too_small" is set to "false", or
2220   // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2221   if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2222       !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
2223     return produceNumOpsError();
2224 
2225   // Distribute payload ops.
2226   SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2227   if (getOverflowResult())
2228     resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2229                                                 getNumResults());
2230   for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2231     int64_t resultNum = en.index();
2232     if (resultNum >= getNumResults())
2233       resultNum = *getOverflowResult();
2234     resultHandles[resultNum].push_back(en.value());
2235   }
2236 
2237   // Set transform op results.
2238   for (auto &&it : llvm::enumerate(resultHandles))
2239     results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2240 
2241   return DiagnosedSilenceableFailure::success();
2242 }
2243 
2244 void transform::SplitHandleOp::getEffects(
2245     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2246   onlyReadsHandle(getHandle(), effects);
2247   producesHandle(getResults(), effects);
2248   // There are no effects on the Payload IR as this is only a handle
2249   // manipulation.
2250 }
2251 
2252 LogicalResult transform::SplitHandleOp::verify() {
2253   if (getOverflowResult().has_value() &&
2254       !(*getOverflowResult() < getNumResults()))
2255     return emitOpError("overflow_result is not a valid result index");
2256   return success();
2257 }
2258 
2259 //===----------------------------------------------------------------------===//
2260 // ReplicateOp
2261 //===----------------------------------------------------------------------===//
2262 
2263 DiagnosedSilenceableFailure
2264 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2265                               transform::TransformResults &results,
2266                               transform::TransformState &state) {
2267   unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2268   for (const auto &en : llvm::enumerate(getHandles())) {
2269     Value handle = en.value();
2270     if (isa<TransformHandleTypeInterface>(handle.getType())) {
2271       SmallVector<Operation *> current =
2272           llvm::to_vector(state.getPayloadOps(handle));
2273       SmallVector<Operation *> payload;
2274       payload.reserve(numRepetitions * current.size());
2275       for (unsigned i = 0; i < numRepetitions; ++i)
2276         llvm::append_range(payload, current);
2277       results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2278     } else {
2279       assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2280              "expected param type");
2281       ArrayRef<Attribute> current = state.getParams(handle);
2282       SmallVector<Attribute> params;
2283       params.reserve(numRepetitions * current.size());
2284       for (unsigned i = 0; i < numRepetitions; ++i)
2285         llvm::append_range(params, current);
2286       results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2287                         params);
2288     }
2289   }
2290   return DiagnosedSilenceableFailure::success();
2291 }
2292 
2293 void transform::ReplicateOp::getEffects(
2294     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2295   onlyReadsHandle(getPattern(), effects);
2296   onlyReadsHandle(getHandles(), effects);
2297   producesHandle(getReplicated(), effects);
2298 }
2299 
2300 //===----------------------------------------------------------------------===//
2301 // SequenceOp
2302 //===----------------------------------------------------------------------===//
2303 
2304 DiagnosedSilenceableFailure
2305 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2306                              transform::TransformResults &results,
2307                              transform::TransformState &state) {
2308   // Map the entry block argument to the list of operations.
2309   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2310   if (failed(mapBlockArguments(state)))
2311     return DiagnosedSilenceableFailure::definiteFailure();
2312 
2313   return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2314                             results);
2315 }
2316 
2317 static ParseResult parseSequenceOpOperands(
2318     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2319     Type &rootType,
2320     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2321     SmallVectorImpl<Type> &extraBindingTypes) {
2322   OpAsmParser::UnresolvedOperand rootOperand;
2323   OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2324   if (!hasRoot.has_value()) {
2325     root = std::nullopt;
2326     return success();
2327   }
2328   if (failed(hasRoot.value()))
2329     return failure();
2330   root = rootOperand;
2331 
2332   if (succeeded(parser.parseOptionalComma())) {
2333     if (failed(parser.parseOperandList(extraBindings)))
2334       return failure();
2335   }
2336   if (failed(parser.parseColon()))
2337     return failure();
2338 
2339   // The paren is truly optional.
2340   (void)parser.parseOptionalLParen();
2341 
2342   if (failed(parser.parseType(rootType))) {
2343     return failure();
2344   }
2345 
2346   if (!extraBindings.empty()) {
2347     if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2348       return failure();
2349   }
2350 
2351   if (extraBindingTypes.size() != extraBindings.size()) {
2352     return parser.emitError(parser.getNameLoc(),
2353                             "expected types to be provided for all operands");
2354   }
2355 
2356   // The paren is truly optional.
2357   (void)parser.parseOptionalRParen();
2358   return success();
2359 }
2360 
2361 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2362                                     Value root, Type rootType,
2363                                     ValueRange extraBindings,
2364                                     TypeRange extraBindingTypes) {
2365   if (!root)
2366     return;
2367 
2368   printer << root;
2369   bool hasExtras = !extraBindings.empty();
2370   if (hasExtras) {
2371     printer << ", ";
2372     printer.printOperands(extraBindings);
2373   }
2374 
2375   printer << " : ";
2376   if (hasExtras)
2377     printer << "(";
2378 
2379   printer << rootType;
2380   if (hasExtras) {
2381     printer << ", ";
2382     llvm::interleaveComma(extraBindingTypes, printer.getStream());
2383     printer << ")";
2384   }
2385 }
2386 
2387 /// Returns `true` if the given op operand may be consuming the handle value in
2388 /// the Transform IR. That is, if it may have a Free effect on it.
2389 static bool isValueUsePotentialConsumer(OpOperand &use) {
2390   // Conservatively assume the effect being present in absence of the interface.
2391   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2392   if (!iface)
2393     return true;
2394 
2395   return isHandleConsumed(use.get(), iface);
2396 }
2397 
2398 LogicalResult
2399 checkDoubleConsume(Value value,
2400                    function_ref<InFlightDiagnostic()> reportError) {
2401   OpOperand *potentialConsumer = nullptr;
2402   for (OpOperand &use : value.getUses()) {
2403     if (!isValueUsePotentialConsumer(use))
2404       continue;
2405 
2406     if (!potentialConsumer) {
2407       potentialConsumer = &use;
2408       continue;
2409     }
2410 
2411     InFlightDiagnostic diag = reportError()
2412                               << " has more than one potential consumer";
2413     diag.attachNote(potentialConsumer->getOwner()->getLoc())
2414         << "used here as operand #" << potentialConsumer->getOperandNumber();
2415     diag.attachNote(use.getOwner()->getLoc())
2416         << "used here as operand #" << use.getOperandNumber();
2417     return diag;
2418   }
2419 
2420   return success();
2421 }
2422 
2423 LogicalResult transform::SequenceOp::verify() {
2424   assert(getBodyBlock()->getNumArguments() >= 1 &&
2425          "the number of arguments must have been verified to be more than 1 by "
2426          "PossibleTopLevelTransformOpTrait");
2427 
2428   if (!getRoot() && !getExtraBindings().empty()) {
2429     return emitOpError()
2430            << "does not expect extra operands when used as top-level";
2431   }
2432 
2433   // Check if a block argument has more than one consuming use.
2434   for (BlockArgument arg : getBodyBlock()->getArguments()) {
2435     if (failed(checkDoubleConsume(arg, [this, arg]() {
2436           return (emitOpError() << "block argument #" << arg.getArgNumber());
2437         }))) {
2438       return failure();
2439     }
2440   }
2441 
2442   // Check properties of the nested operations they cannot check themselves.
2443   for (Operation &child : *getBodyBlock()) {
2444     if (!isa<TransformOpInterface>(child) &&
2445         &child != &getBodyBlock()->back()) {
2446       InFlightDiagnostic diag =
2447           emitOpError()
2448           << "expected children ops to implement TransformOpInterface";
2449       diag.attachNote(child.getLoc()) << "op without interface";
2450       return diag;
2451     }
2452 
2453     for (OpResult result : child.getResults()) {
2454       auto report = [&]() {
2455         return (child.emitError() << "result #" << result.getResultNumber());
2456       };
2457       if (failed(checkDoubleConsume(result, report)))
2458         return failure();
2459     }
2460   }
2461 
2462   if (!getBodyBlock()->mightHaveTerminator())
2463     return emitOpError() << "expects to have a terminator in the body";
2464 
2465   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2466       getOperation()->getResultTypes()) {
2467     InFlightDiagnostic diag = emitOpError()
2468                               << "expects the types of the terminator operands "
2469                                  "to match the types of the result";
2470     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2471     return diag;
2472   }
2473   return success();
2474 }
2475 
2476 void transform::SequenceOp::getEffects(
2477     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2478   getPotentialTopLevelEffects(effects);
2479 }
2480 
2481 OperandRange
2482 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2483   assert(point == getBody() && "unexpected region index");
2484   if (getOperation()->getNumOperands() > 0)
2485     return getOperation()->getOperands();
2486   return OperandRange(getOperation()->operand_end(),
2487                       getOperation()->operand_end());
2488 }
2489 
2490 void transform::SequenceOp::getSuccessorRegions(
2491     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2492   if (point.isParent()) {
2493     Region *bodyRegion = &getBody();
2494     regions.emplace_back(bodyRegion, getNumOperands() != 0
2495                                          ? bodyRegion->getArguments()
2496                                          : Block::BlockArgListType());
2497     return;
2498   }
2499 
2500   assert(point == getBody() && "unexpected region index");
2501   regions.emplace_back(getOperation()->getResults());
2502 }
2503 
2504 void transform::SequenceOp::getRegionInvocationBounds(
2505     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2506   (void)operands;
2507   bounds.emplace_back(1, 1);
2508 }
2509 
2510 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2511                                   TypeRange resultTypes,
2512                                   FailurePropagationMode failurePropagationMode,
2513                                   Value root,
2514                                   SequenceBodyBuilderFn bodyBuilder) {
2515   build(builder, state, resultTypes, failurePropagationMode, root,
2516         /*extra_bindings=*/ValueRange());
2517   Type bbArgType = root.getType();
2518   buildSequenceBody(builder, state, bbArgType,
2519                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2520 }
2521 
2522 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2523                                   TypeRange resultTypes,
2524                                   FailurePropagationMode failurePropagationMode,
2525                                   Value root, ValueRange extraBindings,
2526                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2527   build(builder, state, resultTypes, failurePropagationMode, root,
2528         extraBindings);
2529   buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2530                     bodyBuilder);
2531 }
2532 
2533 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2534                                   TypeRange resultTypes,
2535                                   FailurePropagationMode failurePropagationMode,
2536                                   Type bbArgType,
2537                                   SequenceBodyBuilderFn bodyBuilder) {
2538   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2539         /*extra_bindings=*/ValueRange());
2540   buildSequenceBody(builder, state, bbArgType,
2541                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2542 }
2543 
2544 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2545                                   TypeRange resultTypes,
2546                                   FailurePropagationMode failurePropagationMode,
2547                                   Type bbArgType, TypeRange extraBindingTypes,
2548                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2549   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2550         /*extra_bindings=*/ValueRange());
2551   buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2552 }
2553 
2554 //===----------------------------------------------------------------------===//
2555 // PrintOp
2556 //===----------------------------------------------------------------------===//
2557 
2558 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2559                                StringRef name) {
2560   if (!name.empty())
2561     result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2562 }
2563 
2564 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2565                                Value target, StringRef name) {
2566   result.addOperands({target});
2567   build(builder, result, name);
2568 }
2569 
2570 DiagnosedSilenceableFailure
2571 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2572                           transform::TransformResults &results,
2573                           transform::TransformState &state) {
2574   llvm::outs() << "[[[ IR printer: ";
2575   if (getName().has_value())
2576     llvm::outs() << *getName() << " ";
2577 
2578   if (!getTarget()) {
2579     llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
2580     return DiagnosedSilenceableFailure::success();
2581   }
2582 
2583   llvm::outs() << "]]]\n";
2584   for (Operation *target : state.getPayloadOps(getTarget()))
2585     llvm::outs() << *target << "\n";
2586 
2587   return DiagnosedSilenceableFailure::success();
2588 }
2589 
2590 void transform::PrintOp::getEffects(
2591     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2592   onlyReadsHandle(getTarget(), effects);
2593   onlyReadsPayload(effects);
2594 
2595   // There is no resource for stderr file descriptor, so just declare print
2596   // writes into the default resource.
2597   effects.emplace_back(MemoryEffects::Write::get());
2598 }
2599 
2600 //===----------------------------------------------------------------------===//
2601 // VerifyOp
2602 //===----------------------------------------------------------------------===//
2603 
2604 DiagnosedSilenceableFailure
2605 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2606                                 Operation *target,
2607                                 transform::ApplyToEachResultList &results,
2608                                 transform::TransformState &state) {
2609   if (failed(::mlir::verify(target))) {
2610     DiagnosedDefiniteFailure diag = emitDefiniteFailure()
2611                                     << "failed to verify payload op";
2612     diag.attachNote(target->getLoc()) << "payload op";
2613     return diag;
2614   }
2615   return DiagnosedSilenceableFailure::success();
2616 }
2617 
2618 void transform::VerifyOp::getEffects(
2619     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2620   transform::onlyReadsHandle(getTarget(), effects);
2621 }
2622 
2623 //===----------------------------------------------------------------------===//
2624 // YieldOp
2625 //===----------------------------------------------------------------------===//
2626 
2627 void transform::YieldOp::getEffects(
2628     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2629   onlyReadsHandle(getOperands(), effects);
2630 }
2631