xref: /llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp (revision f6bfbc87779ef2079e9b1356ac21381659f13fbb)
1c63d2b2cSMatthias Springer //===- TransformOps.cpp - Transform dialect operations --------------------===//
20eb403adSAlex Zinenko //
30eb403adSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40eb403adSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
50eb403adSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60eb403adSAlex Zinenko //
70eb403adSAlex Zinenko //===----------------------------------------------------------------------===//
80eb403adSAlex Zinenko 
90eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.h"
10c63d2b2cSMatthias Springer 
110bb4d4d3SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12920c4612SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
130bb4d4d3SMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
143fe7127dSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
1530f22429SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
166fe03096SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h"
1791856b34SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
185a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1942b16035SQuinn Dawkins #include "mlir/IR/BuiltinAttributes.h"
2063c9d2b1SAlex Zinenko #include "mlir/IR/Diagnostics.h"
212c95ede4SMatthias Springer #include "mlir/IR/Dominance.h"
22e4b04b39SOleksandr "Alex" Zinenko #include "mlir/IR/OpImplementation.h"
23a8cfa7cbSJakub Kuderski #include "mlir/IR/OperationSupport.h"
2430f22429SAlex Zinenko #include "mlir/IR/PatternMatch.h"
257dfcd4b7SMatthias Springer #include "mlir/IR/Verifier.h"
2642b16035SQuinn Dawkins #include "mlir/Interfaces/CallInterfaces.h"
2773c3dff1SAlex Zinenko #include "mlir/Interfaces/ControlFlowInterfaces.h"
2834a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionImplementation.h"
29633d9184SOleksandr "Alex" Zinenko #include "mlir/Interfaces/FunctionInterfaces.h"
3018ec2030SMatthias Springer #include "mlir/Pass/Pass.h"
3118ec2030SMatthias Springer #include "mlir/Pass/PassManager.h"
3218ec2030SMatthias Springer #include "mlir/Pass/PassRegistry.h"
332c95ede4SMatthias Springer #include "mlir/Transforms/CSE.h"
34bcfdb3e4SMatthias Springer #include "mlir/Transforms/DialectConversion.h"
350b52fa90SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
36fa1a23a7SMatthias Springer #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
3742b16035SQuinn Dawkins #include "llvm/ADT/DenseSet.h"
38bba85ebdSAlex Zinenko #include "llvm/ADT/STLExtras.h"
3930f22429SAlex Zinenko #include "llvm/ADT/ScopeExit.h"
4063c9d2b1SAlex Zinenko #include "llvm/ADT/SmallPtrSet.h"
41f90b6090SOleksandr "Alex" Zinenko #include "llvm/ADT/TypeSwitch.h"
42e3890b7fSAlex Zinenko #include "llvm/Support/Debug.h"
4342b16035SQuinn Dawkins #include "llvm/Support/ErrorHandling.h"
4405423905SKazu Hirata #include <optional>
45e3890b7fSAlex Zinenko 
46e3890b7fSAlex Zinenko #define DEBUG_TYPE "transform-dialect"
47e3890b7fSAlex Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
480eb403adSAlex Zinenko 
4963c9d2b1SAlex Zinenko #define DEBUG_TYPE_MATCHER "transform-matcher"
5063c9d2b1SAlex Zinenko #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
5163c9d2b1SAlex Zinenko #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
5263c9d2b1SAlex Zinenko 
530eb403adSAlex Zinenko using namespace mlir;
540eb403adSAlex Zinenko 
55b9e40cdeSAlex Zinenko static ParseResult parseSequenceOpOperands(
5605423905SKazu Hirata     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
57b9e40cdeSAlex Zinenko     Type &rootType,
58b9e40cdeSAlex Zinenko     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59b9e40cdeSAlex Zinenko     SmallVectorImpl<Type> &extraBindingTypes);
60b9e40cdeSAlex Zinenko static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
61b9e40cdeSAlex Zinenko                                     Value root, Type rootType,
62b9e40cdeSAlex Zinenko                                     ValueRange extraBindings,
63b9e40cdeSAlex Zinenko                                     TypeRange extraBindingTypes);
6463c9d2b1SAlex Zinenko static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
6563c9d2b1SAlex Zinenko                                      ArrayAttr matchers, ArrayAttr actions);
6663c9d2b1SAlex Zinenko static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
6763c9d2b1SAlex Zinenko                                             ArrayAttr &matchers,
6863c9d2b1SAlex Zinenko                                             ArrayAttr &actions);
69b9e40cdeSAlex Zinenko 
7018ec2030SMatthias Springer /// Helper function to check if the given transform op is contained in (or
7118ec2030SMatthias Springer /// equal to) the given payload target op. In that case, an error is returned.
7218ec2030SMatthias Springer /// Transforming transform IR that is currently executing is generally unsafe.
7318ec2030SMatthias Springer static DiagnosedSilenceableFailure
7418ec2030SMatthias Springer ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
7518ec2030SMatthias Springer                                      Operation *payload) {
7618ec2030SMatthias Springer   Operation *transformAncestor = transform.getOperation();
7718ec2030SMatthias Springer   while (transformAncestor) {
7818ec2030SMatthias Springer     if (transformAncestor == payload) {
7918ec2030SMatthias Springer       DiagnosedDefiniteFailure diag =
8018ec2030SMatthias Springer           transform.emitDefiniteFailure()
8118ec2030SMatthias Springer           << "cannot apply transform to itself (or one of its ancestors)";
8218ec2030SMatthias Springer       diag.attachNote(payload->getLoc()) << "target payload op";
8318ec2030SMatthias Springer       return diag;
8418ec2030SMatthias Springer     }
8518ec2030SMatthias Springer     transformAncestor = transformAncestor->getParentOp();
8618ec2030SMatthias Springer   }
8718ec2030SMatthias Springer   return DiagnosedSilenceableFailure::success();
8818ec2030SMatthias Springer }
8918ec2030SMatthias Springer 
900eb403adSAlex Zinenko #define GET_OP_CLASSES
910eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
920eb403adSAlex Zinenko 
9330f22429SAlex Zinenko //===----------------------------------------------------------------------===//
94e3890b7fSAlex Zinenko // AlternativesOp
95e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===//
96e3890b7fSAlex Zinenko 
974dd744acSMarkus Böck OperandRange
984dd744acSMarkus Böck transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
994dd744acSMarkus Böck   if (!point.isParent() && getOperation()->getNumOperands() == 1)
100e3890b7fSAlex Zinenko     return getOperation()->getOperands();
101e3890b7fSAlex Zinenko   return OperandRange(getOperation()->operand_end(),
102e3890b7fSAlex Zinenko                       getOperation()->operand_end());
103e3890b7fSAlex Zinenko }
104e3890b7fSAlex Zinenko 
105e3890b7fSAlex Zinenko void transform::AlternativesOp::getSuccessorRegions(
1064dd744acSMarkus Böck     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
107491d2701SKazu Hirata   for (Region &alternative : llvm::drop_begin(
1084dd744acSMarkus Böck            getAlternatives(),
1094dd744acSMarkus Böck            point.isParent() ? 0
1104dd744acSMarkus Böck                             : point.getRegionOrNull()->getRegionNumber() + 1)) {
111e3890b7fSAlex Zinenko     regions.emplace_back(&alternative, !getOperands().empty()
112e3890b7fSAlex Zinenko                                            ? alternative.getArguments()
113e3890b7fSAlex Zinenko                                            : Block::BlockArgListType());
114e3890b7fSAlex Zinenko   }
1154dd744acSMarkus Böck   if (!point.isParent())
116e3890b7fSAlex Zinenko     regions.emplace_back(getOperation()->getResults());
117e3890b7fSAlex Zinenko }
118e3890b7fSAlex Zinenko 
119e3890b7fSAlex Zinenko void transform::AlternativesOp::getRegionInvocationBounds(
120e3890b7fSAlex Zinenko     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
121e3890b7fSAlex Zinenko   (void)operands;
122e3890b7fSAlex Zinenko   // The region corresponding to the first alternative is always executed, the
123e3890b7fSAlex Zinenko   // remaining may or may not be executed.
124e3890b7fSAlex Zinenko   bounds.reserve(getNumRegions());
125e3890b7fSAlex Zinenko   bounds.emplace_back(1, 1);
126e3890b7fSAlex Zinenko   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
127e3890b7fSAlex Zinenko }
128e3890b7fSAlex Zinenko 
129aa6a6c56SNicolas Vasilache static void forwardEmptyOperands(Block *block, transform::TransformState &state,
130aa6a6c56SNicolas Vasilache                                  transform::TransformResults &results) {
131aa6a6c56SNicolas Vasilache   for (const auto &res : block->getParentOp()->getOpResults())
132aa6a6c56SNicolas Vasilache     results.set(res, {});
133aa6a6c56SNicolas Vasilache }
134aa6a6c56SNicolas Vasilache 
1351d45282aSAlex Zinenko DiagnosedSilenceableFailure
136c63d2b2cSMatthias Springer transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
137c63d2b2cSMatthias Springer                                  transform::TransformResults &results,
138e3890b7fSAlex Zinenko                                  transform::TransformState &state) {
139e3890b7fSAlex Zinenko   SmallVector<Operation *> originals;
140e3890b7fSAlex Zinenko   if (Value scopeHandle = getScope())
141e3890b7fSAlex Zinenko     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
142e3890b7fSAlex Zinenko   else
143e3890b7fSAlex Zinenko     originals.push_back(state.getTopLevel());
144e3890b7fSAlex Zinenko 
145e3890b7fSAlex Zinenko   for (Operation *original : originals) {
146e3890b7fSAlex Zinenko     if (original->isAncestor(getOperation())) {
147b0bf7fffSAlex Zinenko       auto diag = emitDefiniteFailure()
148b0bf7fffSAlex Zinenko                   << "scope must not contain the transforms being applied";
149e3890b7fSAlex Zinenko       diag.attachNote(original->getLoc()) << "scope";
150b0bf7fffSAlex Zinenko       return diag;
1511d45282aSAlex Zinenko     }
1521d45282aSAlex Zinenko     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
153b0bf7fffSAlex Zinenko       auto diag = emitDefiniteFailure()
1541d45282aSAlex Zinenko                   << "only isolated-from-above ops can be alternative scopes";
1551d45282aSAlex Zinenko       diag.attachNote(original->getLoc()) << "scope";
156b0bf7fffSAlex Zinenko       return diag;
157e3890b7fSAlex Zinenko     }
158e3890b7fSAlex Zinenko   }
159e3890b7fSAlex Zinenko 
160e3890b7fSAlex Zinenko   for (Region &reg : getAlternatives()) {
161e3890b7fSAlex Zinenko     // Clone the scope operations and make the transforms in this alternative
162e3890b7fSAlex Zinenko     // region apply to them by virtue of mapping the block argument (the only
163e3890b7fSAlex Zinenko     // visible handle) to the cloned scope operations. This effectively prevents
164e3890b7fSAlex Zinenko     // the transformation from accessing any IR outside the scope.
165e3890b7fSAlex Zinenko     auto scope = state.make_region_scope(reg);
166e3890b7fSAlex Zinenko     auto clones = llvm::to_vector(
167e3890b7fSAlex Zinenko         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
168e3890b7fSAlex Zinenko     auto deleteClones = llvm::make_scope_exit([&] {
169e3890b7fSAlex Zinenko       for (Operation *clone : clones)
170e3890b7fSAlex Zinenko         clone->erase();
171e3890b7fSAlex Zinenko     });
172bba85ebdSAlex Zinenko     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
173bba85ebdSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
174e3890b7fSAlex Zinenko 
175e3890b7fSAlex Zinenko     bool failed = false;
176e3890b7fSAlex Zinenko     for (Operation &transform : reg.front().without_terminator()) {
1771d45282aSAlex Zinenko       DiagnosedSilenceableFailure result =
178e3890b7fSAlex Zinenko           state.applyTransform(cast<TransformOpInterface>(transform));
1791d45282aSAlex Zinenko       if (result.isSilenceableFailure()) {
180e3890b7fSAlex Zinenko         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
181e3890b7fSAlex Zinenko                           << "\n");
182e3890b7fSAlex Zinenko         failed = true;
183e3890b7fSAlex Zinenko         break;
184e3890b7fSAlex Zinenko       }
185e3890b7fSAlex Zinenko 
186e3890b7fSAlex Zinenko       if (::mlir::failed(result.silence()))
1871d45282aSAlex Zinenko         return DiagnosedSilenceableFailure::definiteFailure();
188e3890b7fSAlex Zinenko     }
189e3890b7fSAlex Zinenko 
190e3890b7fSAlex Zinenko     // If all operations in the given alternative succeeded, no need to consider
191e3890b7fSAlex Zinenko     // the rest. Replace the original scoping operation with the clone on which
192e3890b7fSAlex Zinenko     // the transformations were performed.
193e3890b7fSAlex Zinenko     if (!failed) {
194e3890b7fSAlex Zinenko       // We will be using the clones, so cancel their scheduled deletion.
195e3890b7fSAlex Zinenko       deleteClones.release();
196905e9324SMatthias Springer       TrackingListener listener(state, *this);
19707fef178SMatthias Springer       IRRewriter rewriter(getContext(), &listener);
198e3890b7fSAlex Zinenko       for (const auto &kvp : llvm::zip(originals, clones)) {
199e3890b7fSAlex Zinenko         Operation *original = std::get<0>(kvp);
200e3890b7fSAlex Zinenko         Operation *clone = std::get<1>(kvp);
201e3890b7fSAlex Zinenko         original->getBlock()->getOperations().insert(original->getIterator(),
202e3890b7fSAlex Zinenko                                                      clone);
203e3890b7fSAlex Zinenko         rewriter.replaceOp(original, clone->getResults());
204e3890b7fSAlex Zinenko       }
20563c9d2b1SAlex Zinenko       detail::forwardTerminatorOperands(&reg.front(), state, results);
2061d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::success();
207e3890b7fSAlex Zinenko     }
208e3890b7fSAlex Zinenko   }
2091d45282aSAlex Zinenko   return emitSilenceableError() << "all alternatives failed";
210e3890b7fSAlex Zinenko }
211e3890b7fSAlex Zinenko 
212d46afeefSAlex Zinenko void transform::AlternativesOp::getEffects(
213d46afeefSAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2142c1ae801Sdonald chen   consumesHandle(getOperation()->getOpOperands(), effects);
2152c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
216d46afeefSAlex Zinenko   for (Region *region : getRegions()) {
217d46afeefSAlex Zinenko     if (!region->empty())
218d46afeefSAlex Zinenko       producesHandle(region->front().getArguments(), effects);
219d46afeefSAlex Zinenko   }
220d46afeefSAlex Zinenko   modifiesPayload(effects);
221d46afeefSAlex Zinenko }
222d46afeefSAlex Zinenko 
223e3890b7fSAlex Zinenko LogicalResult transform::AlternativesOp::verify() {
224e3890b7fSAlex Zinenko   for (Region &alternative : getAlternatives()) {
225e3890b7fSAlex Zinenko     Block &block = alternative.front();
226e3890b7fSAlex Zinenko     Operation *terminator = block.getTerminator();
227e3890b7fSAlex Zinenko     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
228e3890b7fSAlex Zinenko       InFlightDiagnostic diag = emitOpError()
229e3890b7fSAlex Zinenko                                 << "expects terminator operands to have the "
230e3890b7fSAlex Zinenko                                    "same type as results of the operation";
231e3890b7fSAlex Zinenko       diag.attachNote(terminator->getLoc()) << "terminator";
232e3890b7fSAlex Zinenko       return diag;
233e3890b7fSAlex Zinenko     }
234e3890b7fSAlex Zinenko   }
235e3890b7fSAlex Zinenko 
236e3890b7fSAlex Zinenko   return success();
237e3890b7fSAlex Zinenko }
238e3890b7fSAlex Zinenko 
239e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===//
2405e7ac250SQuinn Dawkins // AnnotateOp
2415e7ac250SQuinn Dawkins //===----------------------------------------------------------------------===//
2425e7ac250SQuinn Dawkins 
2435e7ac250SQuinn Dawkins DiagnosedSilenceableFailure
244c63d2b2cSMatthias Springer transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
245c63d2b2cSMatthias Springer                              transform::TransformResults &results,
2465e7ac250SQuinn Dawkins                              transform::TransformState &state) {
2475e7ac250SQuinn Dawkins   SmallVector<Operation *> targets =
2485e7ac250SQuinn Dawkins       llvm::to_vector(state.getPayloadOps(getTarget()));
2495e7ac250SQuinn Dawkins 
2505e7ac250SQuinn Dawkins   Attribute attr = UnitAttr::get(getContext());
2515e7ac250SQuinn Dawkins   if (auto paramH = getParam()) {
2525e7ac250SQuinn Dawkins     ArrayRef<Attribute> params = state.getParams(paramH);
2535e7ac250SQuinn Dawkins     if (params.size() != 1) {
2545e7ac250SQuinn Dawkins       if (targets.size() != params.size()) {
2555e7ac250SQuinn Dawkins         return emitSilenceableError()
2565e7ac250SQuinn Dawkins                << "parameter and target have different payload lengths ("
2575e7ac250SQuinn Dawkins                << params.size() << " vs " << targets.size() << ")";
2585e7ac250SQuinn Dawkins       }
2595e7ac250SQuinn Dawkins       for (auto &&[target, attr] : llvm::zip_equal(targets, params))
2605e7ac250SQuinn Dawkins         target->setAttr(getName(), attr);
2615e7ac250SQuinn Dawkins       return DiagnosedSilenceableFailure::success();
2625e7ac250SQuinn Dawkins     }
2635e7ac250SQuinn Dawkins     attr = params[0];
2645e7ac250SQuinn Dawkins   }
265153661dbSMehdi Amini   for (auto *target : targets)
2665e7ac250SQuinn Dawkins     target->setAttr(getName(), attr);
2675e7ac250SQuinn Dawkins   return DiagnosedSilenceableFailure::success();
2685e7ac250SQuinn Dawkins }
2695e7ac250SQuinn Dawkins 
2705e7ac250SQuinn Dawkins void transform::AnnotateOp::getEffects(
2715e7ac250SQuinn Dawkins     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2722c1ae801Sdonald chen   onlyReadsHandle(getTargetMutable(), effects);
2732c1ae801Sdonald chen   onlyReadsHandle(getParamMutable(), effects);
2745e7ac250SQuinn Dawkins   modifiesPayload(effects);
2755e7ac250SQuinn Dawkins }
2765e7ac250SQuinn Dawkins 
2775e7ac250SQuinn Dawkins //===----------------------------------------------------------------------===//
2782c95ede4SMatthias Springer // ApplyCommonSubexpressionEliminationOp
2792c95ede4SMatthias Springer //===----------------------------------------------------------------------===//
2802c95ede4SMatthias Springer 
2812c95ede4SMatthias Springer DiagnosedSilenceableFailure
2822c95ede4SMatthias Springer transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
2832c95ede4SMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
2842c95ede4SMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
2852c95ede4SMatthias Springer   // Make sure that this transform is not applied to itself. Modifying the
2862c95ede4SMatthias Springer   // transform IR while it is being interpreted is generally dangerous.
2872c95ede4SMatthias Springer   DiagnosedSilenceableFailure payloadCheck =
2882c95ede4SMatthias Springer       ensurePayloadIsSeparateFromTransform(*this, target);
2892c95ede4SMatthias Springer   if (!payloadCheck.succeeded())
2902c95ede4SMatthias Springer     return payloadCheck;
2912c95ede4SMatthias Springer 
2922c95ede4SMatthias Springer   DominanceInfo domInfo;
2932c95ede4SMatthias Springer   mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
2942c95ede4SMatthias Springer   return DiagnosedSilenceableFailure::success();
2952c95ede4SMatthias Springer }
2962c95ede4SMatthias Springer 
2972c95ede4SMatthias Springer void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
2982c95ede4SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2992c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
3002c95ede4SMatthias Springer   transform::modifiesPayload(effects);
3012c95ede4SMatthias Springer }
3022c95ede4SMatthias Springer 
3032c95ede4SMatthias Springer //===----------------------------------------------------------------------===//
304c2d5d348SMatthias Springer // ApplyDeadCodeEliminationOp
305c2d5d348SMatthias Springer //===----------------------------------------------------------------------===//
306c2d5d348SMatthias Springer 
307c2d5d348SMatthias Springer DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
308c2d5d348SMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
309c2d5d348SMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
310c2d5d348SMatthias Springer   // Make sure that this transform is not applied to itself. Modifying the
311c2d5d348SMatthias Springer   // transform IR while it is being interpreted is generally dangerous.
312c2d5d348SMatthias Springer   DiagnosedSilenceableFailure payloadCheck =
313c2d5d348SMatthias Springer       ensurePayloadIsSeparateFromTransform(*this, target);
314c2d5d348SMatthias Springer   if (!payloadCheck.succeeded())
315c2d5d348SMatthias Springer     return payloadCheck;
316c2d5d348SMatthias Springer 
317c2d5d348SMatthias Springer   // Maintain a worklist of potentially dead ops.
318c2d5d348SMatthias Springer   SetVector<Operation *> worklist;
319c2d5d348SMatthias Springer 
320c2d5d348SMatthias Springer   // Helper function that adds all defining ops of used values (operands and
321c2d5d348SMatthias Springer   // operands of nested ops).
322c2d5d348SMatthias Springer   auto addDefiningOpsToWorklist = [&](Operation *op) {
323c2d5d348SMatthias Springer     op->walk([&](Operation *op) {
324c2d5d348SMatthias Springer       for (Value v : op->getOperands())
325c2d5d348SMatthias Springer         if (Operation *defOp = v.getDefiningOp())
326c2d5d348SMatthias Springer           if (target->isProperAncestor(defOp))
327c2d5d348SMatthias Springer             worklist.insert(defOp);
328c2d5d348SMatthias Springer     });
329c2d5d348SMatthias Springer   };
330c2d5d348SMatthias Springer 
331c2d5d348SMatthias Springer   // Helper function that erases an op.
332c2d5d348SMatthias Springer   auto eraseOp = [&](Operation *op) {
333c2d5d348SMatthias Springer     // Remove op and nested ops from the worklist.
334c2d5d348SMatthias Springer     op->walk([&](Operation *op) {
335153661dbSMehdi Amini       const auto *it = llvm::find(worklist, op);
336c2d5d348SMatthias Springer       if (it != worklist.end())
337c2d5d348SMatthias Springer         worklist.erase(it);
338c2d5d348SMatthias Springer     });
339c2d5d348SMatthias Springer     rewriter.eraseOp(op);
340c2d5d348SMatthias Springer   };
341c2d5d348SMatthias Springer 
342c2d5d348SMatthias Springer   // Initial walk over the IR.
343c2d5d348SMatthias Springer   target->walk<WalkOrder::PostOrder>([&](Operation *op) {
344c2d5d348SMatthias Springer     if (op != target && isOpTriviallyDead(op)) {
345c2d5d348SMatthias Springer       addDefiningOpsToWorklist(op);
346c2d5d348SMatthias Springer       eraseOp(op);
347c2d5d348SMatthias Springer     }
348c2d5d348SMatthias Springer   });
349c2d5d348SMatthias Springer 
350c2d5d348SMatthias Springer   // Erase all ops that have become dead.
351c2d5d348SMatthias Springer   while (!worklist.empty()) {
352c2d5d348SMatthias Springer     Operation *op = worklist.pop_back_val();
353c2d5d348SMatthias Springer     if (!isOpTriviallyDead(op))
354c2d5d348SMatthias Springer       continue;
355c2d5d348SMatthias Springer     addDefiningOpsToWorklist(op);
356c2d5d348SMatthias Springer     eraseOp(op);
357c2d5d348SMatthias Springer   }
358c2d5d348SMatthias Springer 
359c2d5d348SMatthias Springer   return DiagnosedSilenceableFailure::success();
360c2d5d348SMatthias Springer }
361c2d5d348SMatthias Springer 
362c2d5d348SMatthias Springer void transform::ApplyDeadCodeEliminationOp::getEffects(
363c2d5d348SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3642c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
365c2d5d348SMatthias Springer   transform::modifiesPayload(effects);
366c2d5d348SMatthias Springer }
367c2d5d348SMatthias Springer 
368c2d5d348SMatthias Springer //===----------------------------------------------------------------------===//
3690b52fa90SMatthias Springer // ApplyPatternsOp
3700b52fa90SMatthias Springer //===----------------------------------------------------------------------===//
3710b52fa90SMatthias Springer 
372c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
373c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
374c63d2b2cSMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
375726d0767SMatthias Springer   // Make sure that this transform is not applied to itself. Modifying the
376726d0767SMatthias Springer   // transform IR while it is being interpreted is generally dangerous. Even
377726d0767SMatthias Springer   // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
378726d0767SMatthias Springer   // performs many additional simplifications such as dead code elimination.
37918ec2030SMatthias Springer   DiagnosedSilenceableFailure payloadCheck =
38018ec2030SMatthias Springer       ensurePayloadIsSeparateFromTransform(*this, target);
38118ec2030SMatthias Springer   if (!payloadCheck.succeeded())
38218ec2030SMatthias Springer     return payloadCheck;
383726d0767SMatthias Springer 
3840b52fa90SMatthias Springer   // Gather all specified patterns.
3850b52fa90SMatthias Springer   MLIRContext *ctx = target->getContext();
3860b52fa90SMatthias Springer   RewritePatternSet patterns(ctx);
3875a10f207SMatthias Springer   if (!getRegion().empty()) {
3885a10f207SMatthias Springer     for (Operation &op : getRegion().front()) {
389e55e36deSOleksandr "Alex" Zinenko       cast<transform::PatternDescriptorOpInterface>(&op)
390e55e36deSOleksandr "Alex" Zinenko           .populatePatternsWithState(patterns, state);
3915a10f207SMatthias Springer     }
3925a10f207SMatthias Springer   }
3930b52fa90SMatthias Springer 
3940b52fa90SMatthias Springer   // Configure the GreedyPatternRewriteDriver.
3950b52fa90SMatthias Springer   GreedyRewriteConfig config;
396c63d2b2cSMatthias Springer   config.listener =
397c63d2b2cSMatthias Springer       static_cast<RewriterBase::Listener *>(rewriter.getListener());
39820245ed4SMatthias Springer   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
3990b52fa90SMatthias Springer 
40037b26bf4SOleksandr "Alex" Zinenko   config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
40137b26bf4SOleksandr "Alex" Zinenko                              ? GreedyRewriteConfig::kNoLimit
40237b26bf4SOleksandr "Alex" Zinenko                              : getMaxIterations();
40337b26bf4SOleksandr "Alex" Zinenko   config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
40437b26bf4SOleksandr "Alex" Zinenko                               ? GreedyRewriteConfig::kNoLimit
40537b26bf4SOleksandr "Alex" Zinenko                               : getMaxNumRewrites();
40637b26bf4SOleksandr "Alex" Zinenko 
40720245ed4SMatthias Springer   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
40820245ed4SMatthias Springer   // was requested, apply the greedy pattern rewrite only once. (The greedy
40920245ed4SMatthias Springer   // pattern rewrite driver already iterates to a fixpoint internally.)
41020245ed4SMatthias Springer   bool cseChanged = false;
41120245ed4SMatthias Springer   // One or two iterations should be sufficient. Stop iterating after a certain
41220245ed4SMatthias Springer   // threshold to make debugging easier.
41320245ed4SMatthias Springer   static const int64_t kNumMaxIterations = 50;
41420245ed4SMatthias Springer   int64_t iteration = 0;
41520245ed4SMatthias Springer   do {
416976d25edSMatthias Springer     LogicalResult result = failure();
417976d25edSMatthias Springer     if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
418976d25edSMatthias Springer       // Op is isolated from above. Apply patterns and also perform region
419976d25edSMatthias Springer       // simplification.
42009dfc571SJacques Pienaar       result = applyPatternsGreedily(target, frozenPatterns, config);
421976d25edSMatthias Springer     } else {
42220245ed4SMatthias Springer       // Manually gather list of ops because the other
42320245ed4SMatthias Springer       // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
42420245ed4SMatthias Springer       // from above. This way, patterns can be applied to ops that are not
42520245ed4SMatthias Springer       // isolated from above. Regions are not being simplified. Furthermore,
42620245ed4SMatthias Springer       // only a single greedy rewrite iteration is performed.
4270b52fa90SMatthias Springer       SmallVector<Operation *> ops;
4280b52fa90SMatthias Springer       target->walk([&](Operation *nestedOp) {
4290b52fa90SMatthias Springer         if (target != nestedOp)
4300b52fa90SMatthias Springer           ops.push_back(nestedOp);
4310b52fa90SMatthias Springer       });
43209dfc571SJacques Pienaar       result = applyOpPatternsGreedily(ops, frozenPatterns, config);
433976d25edSMatthias Springer     }
434976d25edSMatthias Springer 
4350b52fa90SMatthias Springer     // A failure typically indicates that the pattern application did not
4360b52fa90SMatthias Springer     // converge.
4370b52fa90SMatthias Springer     if (failed(result)) {
4380b52fa90SMatthias Springer       return emitSilenceableFailure(target)
4390b52fa90SMatthias Springer              << "greedy pattern application failed";
4400b52fa90SMatthias Springer     }
4410b52fa90SMatthias Springer 
44220245ed4SMatthias Springer     if (getApplyCse()) {
44320245ed4SMatthias Springer       DominanceInfo domInfo;
44420245ed4SMatthias Springer       mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
44520245ed4SMatthias Springer                                           &cseChanged);
44620245ed4SMatthias Springer     }
44720245ed4SMatthias Springer   } while (cseChanged && ++iteration < kNumMaxIterations);
44820245ed4SMatthias Springer 
44920245ed4SMatthias Springer   if (iteration == kNumMaxIterations)
45020245ed4SMatthias Springer     return emitDefiniteFailure() << "fixpoint iteration did not converge";
45120245ed4SMatthias Springer 
4520b52fa90SMatthias Springer   return DiagnosedSilenceableFailure::success();
4530b52fa90SMatthias Springer }
4540b52fa90SMatthias Springer 
4550b52fa90SMatthias Springer LogicalResult transform::ApplyPatternsOp::verify() {
4565a10f207SMatthias Springer   if (!getRegion().empty()) {
4575a10f207SMatthias Springer     for (Operation &op : getRegion().front()) {
4585a10f207SMatthias Springer       if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
4595a10f207SMatthias Springer         InFlightDiagnostic diag = emitOpError()
4605a10f207SMatthias Springer                                   << "expected children ops to implement "
4615a10f207SMatthias Springer                                      "PatternDescriptorOpInterface";
4625a10f207SMatthias Springer         diag.attachNote(op.getLoc()) << "op without interface";
4635a10f207SMatthias Springer         return diag;
4645a10f207SMatthias Springer       }
4655a10f207SMatthias Springer     }
4665a10f207SMatthias Springer   }
4670b52fa90SMatthias Springer   return success();
4680b52fa90SMatthias Springer }
4690b52fa90SMatthias Springer 
4700b52fa90SMatthias Springer void transform::ApplyPatternsOp::getEffects(
4710b52fa90SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4722c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
4730b52fa90SMatthias Springer   transform::modifiesPayload(effects);
4740b52fa90SMatthias Springer }
4750b52fa90SMatthias Springer 
476223a0f63SMatthias Springer void transform::ApplyPatternsOp::build(
477223a0f63SMatthias Springer     OpBuilder &builder, OperationState &result, Value target,
478c63d2b2cSMatthias Springer     function_ref<void(OpBuilder &, Location)> bodyBuilder) {
479223a0f63SMatthias Springer   result.addOperands(target);
480223a0f63SMatthias Springer 
481223a0f63SMatthias Springer   OpBuilder::InsertionGuard g(builder);
482223a0f63SMatthias Springer   Region *region = result.addRegion();
483223a0f63SMatthias Springer   builder.createBlock(region);
484223a0f63SMatthias Springer   if (bodyBuilder)
485223a0f63SMatthias Springer     bodyBuilder(builder, result.location);
486223a0f63SMatthias Springer }
487223a0f63SMatthias Springer 
4880b52fa90SMatthias Springer //===----------------------------------------------------------------------===//
4895a10f207SMatthias Springer // ApplyCanonicalizationPatternsOp
4905a10f207SMatthias Springer //===----------------------------------------------------------------------===//
4915a10f207SMatthias Springer 
4925a10f207SMatthias Springer void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
4935a10f207SMatthias Springer     RewritePatternSet &patterns) {
4945a10f207SMatthias Springer   MLIRContext *ctx = patterns.getContext();
4955a10f207SMatthias Springer   for (Dialect *dialect : ctx->getLoadedDialects())
4965a10f207SMatthias Springer     dialect->getCanonicalizationPatterns(patterns);
4975a10f207SMatthias Springer   for (RegisteredOperationName op : ctx->getRegisteredOperations())
4985a10f207SMatthias Springer     op.getCanonicalizationPatterns(patterns, ctx);
4995a10f207SMatthias Springer }
5005a10f207SMatthias Springer 
5015a10f207SMatthias Springer //===----------------------------------------------------------------------===//
502bcfdb3e4SMatthias Springer // ApplyConversionPatternsOp
503bcfdb3e4SMatthias Springer //===----------------------------------------------------------------------===//
504bcfdb3e4SMatthias Springer 
505bcfdb3e4SMatthias Springer DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
506bcfdb3e4SMatthias Springer     transform::TransformRewriter &rewriter,
507bcfdb3e4SMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
508bcfdb3e4SMatthias Springer   MLIRContext *ctx = getContext();
509bcfdb3e4SMatthias Springer 
51099475f5bSNicolas Vasilache   // Instantiate the default type converter if a type converter builder is
51199475f5bSNicolas Vasilache   // specified.
512bcfdb3e4SMatthias Springer   std::unique_ptr<TypeConverter> defaultTypeConverter;
51399475f5bSNicolas Vasilache   transform::TypeConverterBuilderOpInterface typeConverterBuilder =
51499475f5bSNicolas Vasilache       getDefaultTypeConverter();
51599475f5bSNicolas Vasilache   if (typeConverterBuilder)
51699475f5bSNicolas Vasilache     defaultTypeConverter = typeConverterBuilder.getTypeConverter();
517bcfdb3e4SMatthias Springer 
518bcfdb3e4SMatthias Springer   // Configure conversion target.
519920c4612SNicolas Vasilache   ConversionTarget conversionTarget(*getContext());
520bcfdb3e4SMatthias Springer   if (getLegalOps())
521bcfdb3e4SMatthias Springer     for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522bcfdb3e4SMatthias Springer       conversionTarget.addLegalOp(
523bcfdb3e4SMatthias Springer           OperationName(cast<StringAttr>(attr).getValue(), ctx));
524bcfdb3e4SMatthias Springer   if (getIllegalOps())
525bcfdb3e4SMatthias Springer     for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526bcfdb3e4SMatthias Springer       conversionTarget.addIllegalOp(
527bcfdb3e4SMatthias Springer           OperationName(cast<StringAttr>(attr).getValue(), ctx));
528bcfdb3e4SMatthias Springer   if (getLegalDialects())
529bcfdb3e4SMatthias Springer     for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530bcfdb3e4SMatthias Springer       conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531bcfdb3e4SMatthias Springer   if (getIllegalDialects())
532bcfdb3e4SMatthias Springer     for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533bcfdb3e4SMatthias Springer       conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
534bcfdb3e4SMatthias Springer 
535bcfdb3e4SMatthias Springer   // Gather all specified patterns.
536bcfdb3e4SMatthias Springer   RewritePatternSet patterns(ctx);
53799475f5bSNicolas Vasilache   // Need to keep the converters alive until after pattern application because
53899475f5bSNicolas Vasilache   // the patterns take a reference to an object that would otherwise get out of
53999475f5bSNicolas Vasilache   // scope.
54099475f5bSNicolas Vasilache   SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
541bcfdb3e4SMatthias Springer   if (!getPatterns().empty()) {
542bcfdb3e4SMatthias Springer     for (Operation &op : getPatterns().front()) {
543bcfdb3e4SMatthias Springer       auto descriptor =
544bcfdb3e4SMatthias Springer           cast<transform::ConversionPatternDescriptorOpInterface>(&op);
545bcfdb3e4SMatthias Springer 
546bcfdb3e4SMatthias Springer       // Check if this pattern set specifies a type converter.
547bcfdb3e4SMatthias Springer       std::unique_ptr<TypeConverter> typeConverter =
548bcfdb3e4SMatthias Springer           descriptor.getTypeConverter();
549bcfdb3e4SMatthias Springer       TypeConverter *converter = nullptr;
550bcfdb3e4SMatthias Springer       if (typeConverter) {
55199475f5bSNicolas Vasilache         keepAliveConverters.emplace_back(std::move(typeConverter));
55299475f5bSNicolas Vasilache         converter = keepAliveConverters.back().get();
553bcfdb3e4SMatthias Springer       } else {
554bcfdb3e4SMatthias Springer         // No type converter specified: Use the default type converter.
555bcfdb3e4SMatthias Springer         if (!defaultTypeConverter) {
556bcfdb3e4SMatthias Springer           auto diag = emitDefiniteFailure()
557bcfdb3e4SMatthias Springer                       << "pattern descriptor does not specify type "
558bcfdb3e4SMatthias Springer                          "converter and apply_conversion_patterns op has "
559bcfdb3e4SMatthias Springer                          "no default type converter";
560bcfdb3e4SMatthias Springer           diag.attachNote(op.getLoc()) << "pattern descriptor op";
561bcfdb3e4SMatthias Springer           return diag;
562bcfdb3e4SMatthias Springer         }
563bcfdb3e4SMatthias Springer         converter = defaultTypeConverter.get();
564bcfdb3e4SMatthias Springer       }
565e2d39f79SChristopher Bate 
566e2d39f79SChristopher Bate       // Add descriptor-specific updates to the conversion target, which may
567e2d39f79SChristopher Bate       // depend on the final type converter. In structural converters, the
568e2d39f79SChristopher Bate       // legality of types dictates the dynamic legality of an operation.
569e2d39f79SChristopher Bate       descriptor.populateConversionTargetRules(*converter, conversionTarget);
570e2d39f79SChristopher Bate 
571bcfdb3e4SMatthias Springer       descriptor.populatePatterns(*converter, patterns);
572bcfdb3e4SMatthias Springer     }
573bcfdb3e4SMatthias Springer   }
574bcfdb3e4SMatthias Springer 
575c1029b6aSMatthias Springer   // Attach a tracking listener if handles should be preserved. We configure the
576c1029b6aSMatthias Springer   // listener to allow op replacements with different names, as conversion
577c1029b6aSMatthias Springer   // patterns typically replace ops with replacement ops that have a different
578c1029b6aSMatthias Springer   // name.
579c1029b6aSMatthias Springer   TrackingListenerConfig trackingConfig;
580c1029b6aSMatthias Springer   trackingConfig.requireMatchingReplacementOpName = false;
581c1029b6aSMatthias Springer   ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
582c1029b6aSMatthias Springer   ConversionConfig conversionConfig;
583c1029b6aSMatthias Springer   if (getPreserveHandles())
584c1029b6aSMatthias Springer     conversionConfig.listener = &trackingListener;
585c1029b6aSMatthias Springer 
58699475f5bSNicolas Vasilache   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
587bcfdb3e4SMatthias Springer   for (Operation *target : state.getPayloadOps(getTarget())) {
588bcfdb3e4SMatthias Springer     // Make sure that this transform is not applied to itself. Modifying the
589bcfdb3e4SMatthias Springer     // transform IR while it is being interpreted is generally dangerous.
590bcfdb3e4SMatthias Springer     DiagnosedSilenceableFailure payloadCheck =
591bcfdb3e4SMatthias Springer         ensurePayloadIsSeparateFromTransform(*this, target);
592bcfdb3e4SMatthias Springer     if (!payloadCheck.succeeded())
593bcfdb3e4SMatthias Springer       return payloadCheck;
594bcfdb3e4SMatthias Springer 
595bcfdb3e4SMatthias Springer     LogicalResult status = failure();
596bcfdb3e4SMatthias Springer     if (getPartialConversion()) {
597c1029b6aSMatthias Springer       status = applyPartialConversion(target, conversionTarget, frozenPatterns,
598c1029b6aSMatthias Springer                                       conversionConfig);
599bcfdb3e4SMatthias Springer     } else {
600c1029b6aSMatthias Springer       status = applyFullConversion(target, conversionTarget, frozenPatterns,
601c1029b6aSMatthias Springer                                    conversionConfig);
602bcfdb3e4SMatthias Springer     }
603bcfdb3e4SMatthias Springer 
604c1029b6aSMatthias Springer     // Check dialect conversion state.
605c1029b6aSMatthias Springer     DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
606bcfdb3e4SMatthias Springer     if (failed(status)) {
607c1029b6aSMatthias Springer       diag = emitSilenceableError() << "dialect conversion failed";
608bcfdb3e4SMatthias Springer       diag.attachNote(target->getLoc()) << "target op";
609bcfdb3e4SMatthias Springer     }
610c1029b6aSMatthias Springer 
611c1029b6aSMatthias Springer     // Check tracking listener error state.
612c1029b6aSMatthias Springer     DiagnosedSilenceableFailure trackingFailure =
613c1029b6aSMatthias Springer         trackingListener.checkAndResetError();
614c1029b6aSMatthias Springer     if (!trackingFailure.succeeded()) {
615c1029b6aSMatthias Springer       if (diag.succeeded()) {
616c1029b6aSMatthias Springer         // Tracking failure is the only failure.
617c1029b6aSMatthias Springer         return trackingFailure;
618c1029b6aSMatthias Springer       } else {
619c1029b6aSMatthias Springer         diag.attachNote() << "tracking listener also failed: "
620c1029b6aSMatthias Springer                           << trackingFailure.getMessage();
621c1029b6aSMatthias Springer         (void)trackingFailure.silence();
622c1029b6aSMatthias Springer       }
623c1029b6aSMatthias Springer     }
624c1029b6aSMatthias Springer 
625c1029b6aSMatthias Springer     if (!diag.succeeded())
626c1029b6aSMatthias Springer       return diag;
627bcfdb3e4SMatthias Springer   }
628bcfdb3e4SMatthias Springer 
629bcfdb3e4SMatthias Springer   return DiagnosedSilenceableFailure::success();
630bcfdb3e4SMatthias Springer }
631bcfdb3e4SMatthias Springer 
632bcfdb3e4SMatthias Springer LogicalResult transform::ApplyConversionPatternsOp::verify() {
633bcfdb3e4SMatthias Springer   if (getNumRegions() != 1 && getNumRegions() != 2)
634bcfdb3e4SMatthias Springer     return emitOpError() << "expected 1 or 2 regions";
635bcfdb3e4SMatthias Springer   if (!getPatterns().empty()) {
636bcfdb3e4SMatthias Springer     for (Operation &op : getPatterns().front()) {
637bcfdb3e4SMatthias Springer       if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
638bcfdb3e4SMatthias Springer         InFlightDiagnostic diag =
639bcfdb3e4SMatthias Springer             emitOpError() << "expected pattern children ops to implement "
640bcfdb3e4SMatthias Springer                              "ConversionPatternDescriptorOpInterface";
641bcfdb3e4SMatthias Springer         diag.attachNote(op.getLoc()) << "op without interface";
642bcfdb3e4SMatthias Springer         return diag;
643bcfdb3e4SMatthias Springer       }
644bcfdb3e4SMatthias Springer     }
645bcfdb3e4SMatthias Springer   }
646bcfdb3e4SMatthias Springer   if (getNumRegions() == 2) {
647bcfdb3e4SMatthias Springer     Region &typeConverterRegion = getRegion(1);
648bcfdb3e4SMatthias Springer     if (!llvm::hasSingleElement(typeConverterRegion.front()))
649bcfdb3e4SMatthias Springer       return emitOpError()
650bcfdb3e4SMatthias Springer              << "expected exactly one op in default type converter region";
6514527adc5SDaniel Kuts     Operation *maybeTypeConverter = &typeConverterRegion.front().front();
6527ec88f06SMatthias Springer     auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
6534527adc5SDaniel Kuts         maybeTypeConverter);
6547ec88f06SMatthias Springer     if (!typeConverterOp) {
655bcfdb3e4SMatthias Springer       InFlightDiagnostic diag = emitOpError()
656bcfdb3e4SMatthias Springer                                 << "expected default converter child op to "
657bcfdb3e4SMatthias Springer                                    "implement TypeConverterBuilderOpInterface";
6584527adc5SDaniel Kuts       diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
659bcfdb3e4SMatthias Springer       return diag;
660bcfdb3e4SMatthias Springer     }
6617ec88f06SMatthias Springer     // Check default type converter type.
6627ec88f06SMatthias Springer     if (!getPatterns().empty()) {
6637ec88f06SMatthias Springer       for (Operation &op : getPatterns().front()) {
6647ec88f06SMatthias Springer         auto descriptor =
6657ec88f06SMatthias Springer             cast<transform::ConversionPatternDescriptorOpInterface>(&op);
6667ec88f06SMatthias Springer         if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
6677ec88f06SMatthias Springer           return failure();
6687ec88f06SMatthias Springer       }
6697ec88f06SMatthias Springer     }
670bcfdb3e4SMatthias Springer   }
671bcfdb3e4SMatthias Springer   return success();
672bcfdb3e4SMatthias Springer }
673bcfdb3e4SMatthias Springer 
674bcfdb3e4SMatthias Springer void transform::ApplyConversionPatternsOp::getEffects(
675bcfdb3e4SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676c1029b6aSMatthias Springer   if (!getPreserveHandles()) {
6772c1ae801Sdonald chen     transform::consumesHandle(getTargetMutable(), effects);
678c1029b6aSMatthias Springer   } else {
6792c1ae801Sdonald chen     transform::onlyReadsHandle(getTargetMutable(), effects);
680c1029b6aSMatthias Springer   }
681bcfdb3e4SMatthias Springer   transform::modifiesPayload(effects);
682bcfdb3e4SMatthias Springer }
683bcfdb3e4SMatthias Springer 
684bcfdb3e4SMatthias Springer void transform::ApplyConversionPatternsOp::build(
685bcfdb3e4SMatthias Springer     OpBuilder &builder, OperationState &result, Value target,
686bcfdb3e4SMatthias Springer     function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
687bcfdb3e4SMatthias Springer     function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
688bcfdb3e4SMatthias Springer   result.addOperands(target);
689bcfdb3e4SMatthias Springer 
690bcfdb3e4SMatthias Springer   {
691bcfdb3e4SMatthias Springer     OpBuilder::InsertionGuard g(builder);
692bcfdb3e4SMatthias Springer     Region *region1 = result.addRegion();
693bcfdb3e4SMatthias Springer     builder.createBlock(region1);
694bcfdb3e4SMatthias Springer     if (patternsBodyBuilder)
695bcfdb3e4SMatthias Springer       patternsBodyBuilder(builder, result.location);
696bcfdb3e4SMatthias Springer   }
697bcfdb3e4SMatthias Springer   {
698bcfdb3e4SMatthias Springer     OpBuilder::InsertionGuard g(builder);
699bcfdb3e4SMatthias Springer     Region *region2 = result.addRegion();
700bcfdb3e4SMatthias Springer     builder.createBlock(region2);
701bcfdb3e4SMatthias Springer     if (typeConverterBodyBuilder)
702bcfdb3e4SMatthias Springer       typeConverterBodyBuilder(builder, result.location);
703bcfdb3e4SMatthias Springer   }
704bcfdb3e4SMatthias Springer }
705bcfdb3e4SMatthias Springer 
706bcfdb3e4SMatthias Springer //===----------------------------------------------------------------------===//
7070bb4d4d3SMatthias Springer // ApplyToLLVMConversionPatternsOp
7080bb4d4d3SMatthias Springer //===----------------------------------------------------------------------===//
7090bb4d4d3SMatthias Springer 
7100bb4d4d3SMatthias Springer void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
7110bb4d4d3SMatthias Springer     TypeConverter &typeConverter, RewritePatternSet &patterns) {
7120bb4d4d3SMatthias Springer   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
7130bb4d4d3SMatthias Springer   assert(dialect && "expected that dialect is loaded");
714153661dbSMehdi Amini   auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
7150bb4d4d3SMatthias Springer   // ConversionTarget is currently ignored because the enclosing
7160bb4d4d3SMatthias Springer   // apply_conversion_patterns op sets up its own ConversionTarget.
7170bb4d4d3SMatthias Springer   ConversionTarget target(*getContext());
7180bb4d4d3SMatthias Springer   iface->populateConvertToLLVMConversionPatterns(
7190bb4d4d3SMatthias Springer       target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
7200bb4d4d3SMatthias Springer }
7210bb4d4d3SMatthias Springer 
7220bb4d4d3SMatthias Springer LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
7230bb4d4d3SMatthias Springer     transform::TypeConverterBuilderOpInterface builder) {
7240bb4d4d3SMatthias Springer   if (builder.getTypeConverterType() != "LLVMTypeConverter")
7250bb4d4d3SMatthias Springer     return emitOpError("expected LLVMTypeConverter");
7260bb4d4d3SMatthias Springer   return success();
7270bb4d4d3SMatthias Springer }
7280bb4d4d3SMatthias Springer 
7290bb4d4d3SMatthias Springer LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
7300bb4d4d3SMatthias Springer   Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
7310bb4d4d3SMatthias Springer   if (!dialect)
7320bb4d4d3SMatthias Springer     return emitOpError("unknown dialect or dialect not loaded: ")
7330bb4d4d3SMatthias Springer            << getDialectName();
734153661dbSMehdi Amini   auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
7350bb4d4d3SMatthias Springer   if (!iface)
7360bb4d4d3SMatthias Springer     return emitOpError(
7370bb4d4d3SMatthias Springer                "dialect does not implement ConvertToLLVMPatternInterface or "
7380bb4d4d3SMatthias Springer                "extension was not loaded: ")
7390bb4d4d3SMatthias Springer            << getDialectName();
7400bb4d4d3SMatthias Springer   return success();
7410bb4d4d3SMatthias Springer }
7420bb4d4d3SMatthias Springer 
7430bb4d4d3SMatthias Springer //===----------------------------------------------------------------------===//
744fa1a23a7SMatthias Springer // ApplyLoopInvariantCodeMotionOp
745fa1a23a7SMatthias Springer //===----------------------------------------------------------------------===//
746fa1a23a7SMatthias Springer 
747fa1a23a7SMatthias Springer DiagnosedSilenceableFailure
748fa1a23a7SMatthias Springer transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
749fa1a23a7SMatthias Springer     transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
750fa1a23a7SMatthias Springer     transform::ApplyToEachResultList &results,
751fa1a23a7SMatthias Springer     transform::TransformState &state) {
752fa1a23a7SMatthias Springer   // Currently, LICM does not remove operations, so we don't need tracking.
753fa1a23a7SMatthias Springer   // If this ever changes, add a LICM entry point that takes a rewriter.
754fa1a23a7SMatthias Springer   moveLoopInvariantCode(target);
755fa1a23a7SMatthias Springer   return DiagnosedSilenceableFailure::success();
756fa1a23a7SMatthias Springer }
757fa1a23a7SMatthias Springer 
758fa1a23a7SMatthias Springer void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759fa1a23a7SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
7602c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
761fa1a23a7SMatthias Springer   transform::modifiesPayload(effects);
762fa1a23a7SMatthias Springer }
763fa1a23a7SMatthias Springer 
764fa1a23a7SMatthias Springer //===----------------------------------------------------------------------===//
76518ec2030SMatthias Springer // ApplyRegisteredPassOp
76618ec2030SMatthias Springer //===----------------------------------------------------------------------===//
76718ec2030SMatthias Springer 
768c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
769c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
770c63d2b2cSMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
77118ec2030SMatthias Springer   // Make sure that this transform is not applied to itself. Modifying the
77218ec2030SMatthias Springer   // transform IR while it is being interpreted is generally dangerous. Even
77318ec2030SMatthias Springer   // more so when applying passes because they may perform a wide range of IR
77418ec2030SMatthias Springer   // modifications.
77518ec2030SMatthias Springer   DiagnosedSilenceableFailure payloadCheck =
77618ec2030SMatthias Springer       ensurePayloadIsSeparateFromTransform(*this, target);
77718ec2030SMatthias Springer   if (!payloadCheck.succeeded())
77818ec2030SMatthias Springer     return payloadCheck;
77918ec2030SMatthias Springer 
7802f8690b1SMatthias Springer   // Get pass or pass pipeline from registry.
7812f8690b1SMatthias Springer   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
7822f8690b1SMatthias Springer   if (!info)
7832f8690b1SMatthias Springer     info = PassInfo::lookup(getPassName());
7842f8690b1SMatthias Springer   if (!info)
7852f8690b1SMatthias Springer     return emitDefiniteFailure()
7862f8690b1SMatthias Springer            << "unknown pass or pass pipeline: " << getPassName();
78718ec2030SMatthias Springer 
7882f8690b1SMatthias Springer   // Create pass manager and run the pass or pass pipeline.
78918ec2030SMatthias Springer   PassManager pm(getContext());
7902f8690b1SMatthias Springer   if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
79118ec2030SMatthias Springer         emitError(msg);
79218ec2030SMatthias Springer         return failure();
79318ec2030SMatthias Springer       }))) {
79418ec2030SMatthias Springer     return emitDefiniteFailure()
7952f8690b1SMatthias Springer            << "failed to add pass or pass pipeline to pipeline: "
7962f8690b1SMatthias Springer            << getPassName();
79718ec2030SMatthias Springer   }
79818ec2030SMatthias Springer   if (failed(pm.run(target))) {
79918ec2030SMatthias Springer     auto diag = emitSilenceableError() << "pass pipeline failed";
80018ec2030SMatthias Springer     diag.attachNote(target->getLoc()) << "target op";
80118ec2030SMatthias Springer     return diag;
80218ec2030SMatthias Springer   }
80318ec2030SMatthias Springer 
80418ec2030SMatthias Springer   results.push_back(target);
80518ec2030SMatthias Springer   return DiagnosedSilenceableFailure::success();
80618ec2030SMatthias Springer }
80718ec2030SMatthias Springer 
80818ec2030SMatthias Springer //===----------------------------------------------------------------------===//
8092e5fe721SLorenzo Chelini // CastOp
810bffec215SMatthias Springer //===----------------------------------------------------------------------===//
811bffec215SMatthias Springer 
812bffec215SMatthias Springer DiagnosedSilenceableFailure
813c63d2b2cSMatthias Springer transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
814c63d2b2cSMatthias Springer                               Operation *target, ApplyToEachResultList &results,
815bba85ebdSAlex Zinenko                               transform::TransformState &state) {
816bba85ebdSAlex Zinenko   results.push_back(target);
817bba85ebdSAlex Zinenko   return DiagnosedSilenceableFailure::success();
818bba85ebdSAlex Zinenko }
819bba85ebdSAlex Zinenko 
820bba85ebdSAlex Zinenko void transform::CastOp::getEffects(
821bba85ebdSAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
822bba85ebdSAlex Zinenko   onlyReadsPayload(effects);
8232c1ae801Sdonald chen   onlyReadsHandle(getInputMutable(), effects);
8242c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
825bba85ebdSAlex Zinenko }
826bba85ebdSAlex Zinenko 
827bba85ebdSAlex Zinenko bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
828bba85ebdSAlex Zinenko   assert(inputs.size() == 1 && "expected one input");
829bba85ebdSAlex Zinenko   assert(outputs.size() == 1 && "expected one output");
830bba85ebdSAlex Zinenko   return llvm::all_of(
831bba85ebdSAlex Zinenko       std::initializer_list<Type>{inputs.front(), outputs.front()},
832971b8525SJakub Kuderski       llvm::IsaPred<transform::TransformHandleTypeInterface>);
833bba85ebdSAlex Zinenko }
834bba85ebdSAlex Zinenko 
835bba85ebdSAlex Zinenko //===----------------------------------------------------------------------===//
836633d9184SOleksandr "Alex" Zinenko // CollectMatchingOp
83763c9d2b1SAlex Zinenko //===----------------------------------------------------------------------===//
83863c9d2b1SAlex Zinenko 
839e4b04b39SOleksandr "Alex" Zinenko /// Applies matcher operations from the given `block` using
840e4b04b39SOleksandr "Alex" Zinenko /// `blockArgumentMapping` to initialize block arguments. Updates `state`
841e4b04b39SOleksandr "Alex" Zinenko /// accordingly. If any of the matcher produces a silenceable failure, discards
842e4b04b39SOleksandr "Alex" Zinenko /// it (printing the content to the debug output stream) and returns failure. If
843e4b04b39SOleksandr "Alex" Zinenko /// any of the matchers produces a definite failure, reports it and returns
844e4b04b39SOleksandr "Alex" Zinenko /// failure. If all matchers in the block succeed, populates `mappings` with the
845e4b04b39SOleksandr "Alex" Zinenko /// payload entities associated with the block terminator operands. Note that
846e4b04b39SOleksandr "Alex" Zinenko /// `mappings` will be cleared before that.
84763c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure
848e4b04b39SOleksandr "Alex" Zinenko matchBlock(Block &block,
849e4b04b39SOleksandr "Alex" Zinenko            ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
850e4b04b39SOleksandr "Alex" Zinenko            transform::TransformState &state,
85163c9d2b1SAlex Zinenko            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
85263c9d2b1SAlex Zinenko   assert(block.getParent() && "cannot match using a detached block");
85322259281SMatthias Springer   auto matchScope = state.make_region_scope(*block.getParent());
854e4b04b39SOleksandr "Alex" Zinenko   if (failed(
855e4b04b39SOleksandr "Alex" Zinenko           state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
85663c9d2b1SAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
85763c9d2b1SAlex Zinenko 
85863c9d2b1SAlex Zinenko   for (Operation &match : block.without_terminator()) {
85963c9d2b1SAlex Zinenko     if (!isa<transform::MatchOpInterface>(match)) {
86063c9d2b1SAlex Zinenko       return emitDefiniteFailure(match.getLoc())
86163c9d2b1SAlex Zinenko              << "expected operations in the match part to "
86263c9d2b1SAlex Zinenko                 "implement MatchOpInterface";
86363c9d2b1SAlex Zinenko     }
86463c9d2b1SAlex Zinenko     DiagnosedSilenceableFailure diag =
86563c9d2b1SAlex Zinenko         state.applyTransform(cast<transform::TransformOpInterface>(match));
86663c9d2b1SAlex Zinenko     if (diag.succeeded())
86763c9d2b1SAlex Zinenko       continue;
86863c9d2b1SAlex Zinenko 
86963c9d2b1SAlex Zinenko     return diag;
87063c9d2b1SAlex Zinenko   }
87163c9d2b1SAlex Zinenko 
87263c9d2b1SAlex Zinenko   // Remember the values mapped to the terminator operands so we can
87363c9d2b1SAlex Zinenko   // forward them to the action.
87463c9d2b1SAlex Zinenko   ValueRange yieldedValues = block.getTerminator()->getOperands();
875e4b04b39SOleksandr "Alex" Zinenko   // Our contract with the caller is that the mappings will contain only the
876e4b04b39SOleksandr "Alex" Zinenko   // newly mapped values, clear the rest.
877e4b04b39SOleksandr "Alex" Zinenko   mappings.clear();
87863c9d2b1SAlex Zinenko   transform::detail::prepareValueMappings(mappings, yieldedValues, state);
87963c9d2b1SAlex Zinenko   return DiagnosedSilenceableFailure::success();
88063c9d2b1SAlex Zinenko }
88163c9d2b1SAlex Zinenko 
882633d9184SOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the interfaces provided as
883633d9184SOleksandr "Alex" Zinenko /// template parameters.
884633d9184SOleksandr "Alex" Zinenko template <typename... Tys>
885633d9184SOleksandr "Alex" Zinenko static bool implementSameInterface(Type t1, Type t2) {
886633d9184SOleksandr "Alex" Zinenko   return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
887633d9184SOleksandr "Alex" Zinenko }
888633d9184SOleksandr "Alex" Zinenko 
889633d9184SOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the transform dialect
890633d9184SOleksandr "Alex" Zinenko /// interfaces.
891633d9184SOleksandr "Alex" Zinenko static bool implementSameTransformInterface(Type t1, Type t2) {
892633d9184SOleksandr "Alex" Zinenko   return implementSameInterface<transform::TransformHandleTypeInterface,
893633d9184SOleksandr "Alex" Zinenko                                 transform::TransformParamTypeInterface,
894633d9184SOleksandr "Alex" Zinenko                                 transform::TransformValueHandleTypeInterface>(
895633d9184SOleksandr "Alex" Zinenko       t1, t2);
896633d9184SOleksandr "Alex" Zinenko }
897633d9184SOleksandr "Alex" Zinenko 
898633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
899633d9184SOleksandr "Alex" Zinenko // CollectMatchingOp
900633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
901633d9184SOleksandr "Alex" Zinenko 
902633d9184SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure
903633d9184SOleksandr "Alex" Zinenko transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
904633d9184SOleksandr "Alex" Zinenko                                     transform::TransformResults &results,
905633d9184SOleksandr "Alex" Zinenko                                     transform::TransformState &state) {
906633d9184SOleksandr "Alex" Zinenko   auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907633d9184SOleksandr "Alex" Zinenko       getOperation(), getMatcher());
908633d9184SOleksandr "Alex" Zinenko   if (matcher.isExternal()) {
909633d9184SOleksandr "Alex" Zinenko     return emitDefiniteFailure()
910633d9184SOleksandr "Alex" Zinenko            << "unresolved external symbol " << getMatcher();
911633d9184SOleksandr "Alex" Zinenko   }
912633d9184SOleksandr "Alex" Zinenko 
913633d9184SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>, 2> rawResults;
914633d9184SOleksandr "Alex" Zinenko   rawResults.resize(getOperation()->getNumResults());
915633d9184SOleksandr "Alex" Zinenko   std::optional<DiagnosedSilenceableFailure> maybeFailure;
916633d9184SOleksandr "Alex" Zinenko   for (Operation *root : state.getPayloadOps(getRoot())) {
917633d9184SOleksandr "Alex" Zinenko     WalkResult walkResult = root->walk([&](Operation *op) {
918633d9184SOleksandr "Alex" Zinenko       DEBUG_MATCHER({
919633d9184SOleksandr "Alex" Zinenko         DBGS_MATCHER() << "matching ";
920633d9184SOleksandr "Alex" Zinenko         op->print(llvm::dbgs(),
921633d9184SOleksandr "Alex" Zinenko                   OpPrintingFlags().assumeVerified().skipRegions());
922633d9184SOleksandr "Alex" Zinenko         llvm::dbgs() << " @" << op << "\n";
923633d9184SOleksandr "Alex" Zinenko       });
924633d9184SOleksandr "Alex" Zinenko 
925633d9184SOleksandr "Alex" Zinenko       // Try matching.
926633d9184SOleksandr "Alex" Zinenko       SmallVector<SmallVector<MappedValue>> mappings;
927e4b04b39SOleksandr "Alex" Zinenko       SmallVector<transform::MappedValue> inputMapping({op});
928e4b04b39SOleksandr "Alex" Zinenko       DiagnosedSilenceableFailure diag = matchBlock(
929e4b04b39SOleksandr "Alex" Zinenko           matcher.getFunctionBody().front(),
930e4b04b39SOleksandr "Alex" Zinenko           ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
931e4b04b39SOleksandr "Alex" Zinenko           mappings);
932633d9184SOleksandr "Alex" Zinenko       if (diag.isDefiniteFailure())
933633d9184SOleksandr "Alex" Zinenko         return WalkResult::interrupt();
934633d9184SOleksandr "Alex" Zinenko       if (diag.isSilenceableFailure()) {
935633d9184SOleksandr "Alex" Zinenko         DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
936633d9184SOleksandr "Alex" Zinenko                                      << " failed: " << diag.getMessage());
937633d9184SOleksandr "Alex" Zinenko         return WalkResult::advance();
938633d9184SOleksandr "Alex" Zinenko       }
939633d9184SOleksandr "Alex" Zinenko 
940633d9184SOleksandr "Alex" Zinenko       // If succeeded, collect results.
941633d9184SOleksandr "Alex" Zinenko       for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
942633d9184SOleksandr "Alex" Zinenko         if (mapping.size() != 1) {
943633d9184SOleksandr "Alex" Zinenko           maybeFailure.emplace(emitSilenceableError()
944633d9184SOleksandr "Alex" Zinenko                                << "result #" << i << ", associated with "
945633d9184SOleksandr "Alex" Zinenko                                << mapping.size()
946633d9184SOleksandr "Alex" Zinenko                                << " payload objects, expected 1");
947633d9184SOleksandr "Alex" Zinenko           return WalkResult::interrupt();
948633d9184SOleksandr "Alex" Zinenko         }
949633d9184SOleksandr "Alex" Zinenko         rawResults[i].push_back(mapping[0]);
950633d9184SOleksandr "Alex" Zinenko       }
951633d9184SOleksandr "Alex" Zinenko       return WalkResult::advance();
952633d9184SOleksandr "Alex" Zinenko     });
953633d9184SOleksandr "Alex" Zinenko     if (walkResult.wasInterrupted())
954633d9184SOleksandr "Alex" Zinenko       return std::move(*maybeFailure);
955633d9184SOleksandr "Alex" Zinenko     assert(!maybeFailure && "failure set but the walk was not interrupted");
956633d9184SOleksandr "Alex" Zinenko 
957633d9184SOleksandr "Alex" Zinenko     for (auto &&[opResult, rawResult] :
958633d9184SOleksandr "Alex" Zinenko          llvm::zip_equal(getOperation()->getResults(), rawResults)) {
959633d9184SOleksandr "Alex" Zinenko       results.setMappedValues(opResult, rawResult);
960633d9184SOleksandr "Alex" Zinenko     }
961633d9184SOleksandr "Alex" Zinenko   }
962633d9184SOleksandr "Alex" Zinenko   return DiagnosedSilenceableFailure::success();
963633d9184SOleksandr "Alex" Zinenko }
964633d9184SOleksandr "Alex" Zinenko 
965633d9184SOleksandr "Alex" Zinenko void transform::CollectMatchingOp::getEffects(
966633d9184SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
9672c1ae801Sdonald chen   onlyReadsHandle(getRootMutable(), effects);
9682c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
969633d9184SOleksandr "Alex" Zinenko   onlyReadsPayload(effects);
970633d9184SOleksandr "Alex" Zinenko }
971633d9184SOleksandr "Alex" Zinenko 
972633d9184SOleksandr "Alex" Zinenko LogicalResult transform::CollectMatchingOp::verifySymbolUses(
973633d9184SOleksandr "Alex" Zinenko     SymbolTableCollection &symbolTable) {
974633d9184SOleksandr "Alex" Zinenko   auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
975633d9184SOleksandr "Alex" Zinenko       symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
976633d9184SOleksandr "Alex" Zinenko   if (!matcherSymbol ||
977633d9184SOleksandr "Alex" Zinenko       !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978633d9184SOleksandr "Alex" Zinenko     return emitError() << "unresolved matcher symbol " << getMatcher();
979633d9184SOleksandr "Alex" Zinenko 
980633d9184SOleksandr "Alex" Zinenko   ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
981633d9184SOleksandr "Alex" Zinenko   if (argumentTypes.size() != 1 ||
982633d9184SOleksandr "Alex" Zinenko       !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
983633d9184SOleksandr "Alex" Zinenko     return emitError()
984633d9184SOleksandr "Alex" Zinenko            << "expected the matcher to take one operation handle argument";
985633d9184SOleksandr "Alex" Zinenko   }
986633d9184SOleksandr "Alex" Zinenko   if (!matcherSymbol.getArgAttr(
987633d9184SOleksandr "Alex" Zinenko           0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988633d9184SOleksandr "Alex" Zinenko     return emitError() << "expected the matcher argument to be marked readonly";
989633d9184SOleksandr "Alex" Zinenko   }
990633d9184SOleksandr "Alex" Zinenko 
991633d9184SOleksandr "Alex" Zinenko   ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
992633d9184SOleksandr "Alex" Zinenko   if (resultTypes.size() != getOperation()->getNumResults()) {
993633d9184SOleksandr "Alex" Zinenko     return emitError()
994633d9184SOleksandr "Alex" Zinenko            << "expected the matcher to yield as many values as op has results ("
995633d9184SOleksandr "Alex" Zinenko            << getOperation()->getNumResults() << "), got "
996633d9184SOleksandr "Alex" Zinenko            << resultTypes.size();
997633d9184SOleksandr "Alex" Zinenko   }
998633d9184SOleksandr "Alex" Zinenko 
999633d9184SOleksandr "Alex" Zinenko   for (auto &&[i, matcherType, resultType] :
1000633d9184SOleksandr "Alex" Zinenko        llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1001633d9184SOleksandr "Alex" Zinenko     if (implementSameTransformInterface(matcherType, resultType))
1002633d9184SOleksandr "Alex" Zinenko       continue;
1003633d9184SOleksandr "Alex" Zinenko 
1004633d9184SOleksandr "Alex" Zinenko     return emitError()
1005633d9184SOleksandr "Alex" Zinenko            << "mismatching type interfaces for matcher result and op result #"
1006633d9184SOleksandr "Alex" Zinenko            << i;
1007633d9184SOleksandr "Alex" Zinenko   }
1008633d9184SOleksandr "Alex" Zinenko 
1009633d9184SOleksandr "Alex" Zinenko   return success();
1010633d9184SOleksandr "Alex" Zinenko }
1011633d9184SOleksandr "Alex" Zinenko 
1012633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
1013633d9184SOleksandr "Alex" Zinenko // ForeachMatchOp
1014633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
1015633d9184SOleksandr "Alex" Zinenko 
1016e4b04b39SOleksandr "Alex" Zinenko // This is fine because nothing is actually consumed by this op.
1017e4b04b39SOleksandr "Alex" Zinenko bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1018e4b04b39SOleksandr "Alex" Zinenko 
101963c9d2b1SAlex Zinenko DiagnosedSilenceableFailure
1020c63d2b2cSMatthias Springer transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1021c63d2b2cSMatthias Springer                                  transform::TransformResults &results,
102263c9d2b1SAlex Zinenko                                  transform::TransformState &state) {
102363c9d2b1SAlex Zinenko   SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
102463c9d2b1SAlex Zinenko       matchActionPairs;
102563c9d2b1SAlex Zinenko   matchActionPairs.reserve(getMatchers().size());
102663c9d2b1SAlex Zinenko   SymbolTableCollection symbolTable;
102763c9d2b1SAlex Zinenko   for (auto &&[matcher, action] :
102863c9d2b1SAlex Zinenko        llvm::zip_equal(getMatchers(), getActions())) {
102963c9d2b1SAlex Zinenko     auto matcherSymbol =
103063c9d2b1SAlex Zinenko         symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
103163c9d2b1SAlex Zinenko             getOperation(), cast<SymbolRefAttr>(matcher));
103263c9d2b1SAlex Zinenko     auto actionSymbol =
103363c9d2b1SAlex Zinenko         symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
103463c9d2b1SAlex Zinenko             getOperation(), cast<SymbolRefAttr>(action));
103563c9d2b1SAlex Zinenko     assert(matcherSymbol && actionSymbol &&
103663c9d2b1SAlex Zinenko            "unresolved symbols not caught by the verifier");
103763c9d2b1SAlex Zinenko 
103863c9d2b1SAlex Zinenko     if (matcherSymbol.isExternal())
103963c9d2b1SAlex Zinenko       return emitDefiniteFailure() << "unresolved external symbol " << matcher;
104063c9d2b1SAlex Zinenko     if (actionSymbol.isExternal())
104163c9d2b1SAlex Zinenko       return emitDefiniteFailure() << "unresolved external symbol " << action;
104263c9d2b1SAlex Zinenko 
104363c9d2b1SAlex Zinenko     matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
104463c9d2b1SAlex Zinenko   }
104563c9d2b1SAlex Zinenko 
10460b790572SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure overallDiag =
10470b790572SOleksandr "Alex" Zinenko       DiagnosedSilenceableFailure::success();
1048e4b04b39SOleksandr "Alex" Zinenko 
1049e4b04b39SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>> matchInputMapping;
1050e4b04b39SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1051e4b04b39SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>> actionResultMapping;
1052e4b04b39SOleksandr "Alex" Zinenko   // Explicitly add the mapping for the first block argument (the op being
1053e4b04b39SOleksandr "Alex" Zinenko   // matched).
1054e4b04b39SOleksandr "Alex" Zinenko   matchInputMapping.emplace_back();
1055e4b04b39SOleksandr "Alex" Zinenko   transform::detail::prepareValueMappings(matchInputMapping,
1056e4b04b39SOleksandr "Alex" Zinenko                                           getForwardedInputs(), state);
1057e4b04b39SOleksandr "Alex" Zinenko   SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1058e4b04b39SOleksandr "Alex" Zinenko   actionResultMapping.resize(getForwardedOutputs().size());
1059e4b04b39SOleksandr "Alex" Zinenko 
106063c9d2b1SAlex Zinenko   for (Operation *root : state.getPayloadOps(getRoot())) {
106163c9d2b1SAlex Zinenko     WalkResult walkResult = root->walk([&](Operation *op) {
10628483d18bSNicolas Vasilache       // If getRestrictRoot is not present, skip over the root op itself so we
10638483d18bSNicolas Vasilache       // don't invalidate it.
10648483d18bSNicolas Vasilache       if (!getRestrictRoot() && op == root)
106563c9d2b1SAlex Zinenko         return WalkResult::advance();
106663c9d2b1SAlex Zinenko 
106763c9d2b1SAlex Zinenko       DEBUG_MATCHER({
106863c9d2b1SAlex Zinenko         DBGS_MATCHER() << "matching ";
106963c9d2b1SAlex Zinenko         op->print(llvm::dbgs(),
107063c9d2b1SAlex Zinenko                   OpPrintingFlags().assumeVerified().skipRegions());
107163c9d2b1SAlex Zinenko         llvm::dbgs() << " @" << op << "\n";
107263c9d2b1SAlex Zinenko       });
107363c9d2b1SAlex Zinenko 
1074e4b04b39SOleksandr "Alex" Zinenko       firstMatchArgument.clear();
1075e4b04b39SOleksandr "Alex" Zinenko       firstMatchArgument.push_back(op);
1076e4b04b39SOleksandr "Alex" Zinenko 
107763c9d2b1SAlex Zinenko       // Try all the match/action pairs until the first successful match.
107863c9d2b1SAlex Zinenko       for (auto [matcher, action] : matchActionPairs) {
107963c9d2b1SAlex Zinenko         DiagnosedSilenceableFailure diag =
1080e4b04b39SOleksandr "Alex" Zinenko             matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1081e4b04b39SOleksandr "Alex" Zinenko                        state, matchOutputMapping);
108263c9d2b1SAlex Zinenko         if (diag.isDefiniteFailure())
108363c9d2b1SAlex Zinenko           return WalkResult::interrupt();
108463c9d2b1SAlex Zinenko         if (diag.isSilenceableFailure()) {
10853fe7127dSAlex Zinenko           DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
10863fe7127dSAlex Zinenko                                        << " failed: " << diag.getMessage());
108763c9d2b1SAlex Zinenko           continue;
108863c9d2b1SAlex Zinenko         }
108963c9d2b1SAlex Zinenko 
109022259281SMatthias Springer         auto scope = state.make_region_scope(action.getFunctionBody());
1091e4b04b39SOleksandr "Alex" Zinenko         if (failed(state.mapBlockArguments(
1092e4b04b39SOleksandr "Alex" Zinenko                 action.getFunctionBody().front().getArguments(),
1093e4b04b39SOleksandr "Alex" Zinenko                 matchOutputMapping))) {
109463c9d2b1SAlex Zinenko           return WalkResult::interrupt();
109563c9d2b1SAlex Zinenko         }
109663c9d2b1SAlex Zinenko 
109763c9d2b1SAlex Zinenko         for (Operation &transform :
109863c9d2b1SAlex Zinenko              action.getFunctionBody().front().without_terminator()) {
109963c9d2b1SAlex Zinenko           DiagnosedSilenceableFailure result =
110063c9d2b1SAlex Zinenko               state.applyTransform(cast<TransformOpInterface>(transform));
11010b790572SOleksandr "Alex" Zinenko           if (result.isDefiniteFailure())
110263c9d2b1SAlex Zinenko             return WalkResult::interrupt();
11030b790572SOleksandr "Alex" Zinenko           if (result.isSilenceableFailure()) {
11040b790572SOleksandr "Alex" Zinenko             if (overallDiag.succeeded()) {
11050b790572SOleksandr "Alex" Zinenko               overallDiag = emitSilenceableError() << "actions failed";
11060b790572SOleksandr "Alex" Zinenko             }
11070b790572SOleksandr "Alex" Zinenko             overallDiag.attachNote(action->getLoc())
11080b790572SOleksandr "Alex" Zinenko                 << "failed action: " << result.getMessage();
11090b790572SOleksandr "Alex" Zinenko             overallDiag.attachNote(op->getLoc())
11100b790572SOleksandr "Alex" Zinenko                 << "when applied to this matching payload";
11110b790572SOleksandr "Alex" Zinenko             (void)result.silence();
11120b790572SOleksandr "Alex" Zinenko             continue;
11130b790572SOleksandr "Alex" Zinenko           }
111463c9d2b1SAlex Zinenko         }
1115e4b04b39SOleksandr "Alex" Zinenko         if (failed(detail::appendValueMappings(
1116e4b04b39SOleksandr "Alex" Zinenko                 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1117e4b04b39SOleksandr "Alex" Zinenko                 action.getFunctionBody().front().getTerminator()->getOperands(),
1118e4b04b39SOleksandr "Alex" Zinenko                 state, getFlattenResults()))) {
1119e4b04b39SOleksandr "Alex" Zinenko           emitDefiniteFailure()
1120e4b04b39SOleksandr "Alex" Zinenko               << "action @" << action.getName()
1121e4b04b39SOleksandr "Alex" Zinenko               << " has results associated with multiple payload entities, "
1122e4b04b39SOleksandr "Alex" Zinenko                  "but flattening was not requested";
1123e4b04b39SOleksandr "Alex" Zinenko           return WalkResult::interrupt();
1124e4b04b39SOleksandr "Alex" Zinenko         }
112563c9d2b1SAlex Zinenko         break;
112663c9d2b1SAlex Zinenko       }
112763c9d2b1SAlex Zinenko       return WalkResult::advance();
112863c9d2b1SAlex Zinenko     });
112963c9d2b1SAlex Zinenko     if (walkResult.wasInterrupted())
113063c9d2b1SAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
113163c9d2b1SAlex Zinenko   }
113263c9d2b1SAlex Zinenko 
113363c9d2b1SAlex Zinenko   // The root operation should not have been affected, so we can just reassign
113463c9d2b1SAlex Zinenko   // the payload to the result. Note that we need to consume the root handle to
113563c9d2b1SAlex Zinenko   // make sure any handles to operations inside, that could have been affected
113663c9d2b1SAlex Zinenko   // by actions, are invalidated.
1137c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getUpdated()),
1138c1fa60b4STres Popp               state.getPayloadOps(getRoot()));
1139e4b04b39SOleksandr "Alex" Zinenko   for (auto &&[result, mapping] :
1140e4b04b39SOleksandr "Alex" Zinenko        llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1141e4b04b39SOleksandr "Alex" Zinenko     results.setMappedValues(result, mapping);
1142e4b04b39SOleksandr "Alex" Zinenko   }
11430b790572SOleksandr "Alex" Zinenko   return overallDiag;
114463c9d2b1SAlex Zinenko }
114563c9d2b1SAlex Zinenko 
1146e4b04b39SOleksandr "Alex" Zinenko void transform::ForeachMatchOp::getAsmResultNames(
1147e4b04b39SOleksandr "Alex" Zinenko     OpAsmSetValueNameFn setNameFn) {
1148e4b04b39SOleksandr "Alex" Zinenko   setNameFn(getUpdated(), "updated_root");
1149e4b04b39SOleksandr "Alex" Zinenko   for (Value v : getForwardedOutputs()) {
1150e4b04b39SOleksandr "Alex" Zinenko     setNameFn(v, "yielded");
1151e4b04b39SOleksandr "Alex" Zinenko   }
1152e4b04b39SOleksandr "Alex" Zinenko }
1153e4b04b39SOleksandr "Alex" Zinenko 
115463c9d2b1SAlex Zinenko void transform::ForeachMatchOp::getEffects(
115563c9d2b1SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
115663c9d2b1SAlex Zinenko   // Bail if invalid.
115763c9d2b1SAlex Zinenko   if (getOperation()->getNumOperands() < 1 ||
115863c9d2b1SAlex Zinenko       getOperation()->getNumResults() < 1) {
115963c9d2b1SAlex Zinenko     return modifiesPayload(effects);
116063c9d2b1SAlex Zinenko   }
116163c9d2b1SAlex Zinenko 
11622c1ae801Sdonald chen   consumesHandle(getRootMutable(), effects);
11632c1ae801Sdonald chen   onlyReadsHandle(getForwardedInputsMutable(), effects);
11642c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
116563c9d2b1SAlex Zinenko   modifiesPayload(effects);
116663c9d2b1SAlex Zinenko }
116763c9d2b1SAlex Zinenko 
116863c9d2b1SAlex Zinenko /// Parses the comma-separated list of symbol reference pairs of the format
116963c9d2b1SAlex Zinenko /// `@matcher -> @action`.
117063c9d2b1SAlex Zinenko static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
117163c9d2b1SAlex Zinenko                                             ArrayAttr &matchers,
117263c9d2b1SAlex Zinenko                                             ArrayAttr &actions) {
117363c9d2b1SAlex Zinenko   StringAttr matcher;
117463c9d2b1SAlex Zinenko   StringAttr action;
117563c9d2b1SAlex Zinenko   SmallVector<Attribute> matcherList;
117663c9d2b1SAlex Zinenko   SmallVector<Attribute> actionList;
117763c9d2b1SAlex Zinenko   do {
117863c9d2b1SAlex Zinenko     if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
117963c9d2b1SAlex Zinenko         parser.parseSymbolName(action)) {
118063c9d2b1SAlex Zinenko       return failure();
118163c9d2b1SAlex Zinenko     }
118263c9d2b1SAlex Zinenko     matcherList.push_back(SymbolRefAttr::get(matcher));
118363c9d2b1SAlex Zinenko     actionList.push_back(SymbolRefAttr::get(action));
118463c9d2b1SAlex Zinenko   } while (parser.parseOptionalComma().succeeded());
118563c9d2b1SAlex Zinenko 
118663c9d2b1SAlex Zinenko   matchers = parser.getBuilder().getArrayAttr(matcherList);
118763c9d2b1SAlex Zinenko   actions = parser.getBuilder().getArrayAttr(actionList);
118863c9d2b1SAlex Zinenko   return success();
118963c9d2b1SAlex Zinenko }
119063c9d2b1SAlex Zinenko 
119163c9d2b1SAlex Zinenko /// Prints the comma-separated list of symbol reference pairs of the format
119263c9d2b1SAlex Zinenko /// `@matcher -> @action`.
119363c9d2b1SAlex Zinenko static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
119463c9d2b1SAlex Zinenko                                      ArrayAttr matchers, ArrayAttr actions) {
119563c9d2b1SAlex Zinenko   printer.increaseIndent();
119663c9d2b1SAlex Zinenko   printer.increaseIndent();
119763c9d2b1SAlex Zinenko   for (auto &&[matcher, action, idx] : llvm::zip_equal(
119863c9d2b1SAlex Zinenko            matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
119963c9d2b1SAlex Zinenko     printer.printNewline();
120063c9d2b1SAlex Zinenko     printer << cast<SymbolRefAttr>(matcher) << " -> "
120163c9d2b1SAlex Zinenko             << cast<SymbolRefAttr>(action);
120263c9d2b1SAlex Zinenko     if (idx != matchers.size() - 1)
120363c9d2b1SAlex Zinenko       printer << ", ";
120463c9d2b1SAlex Zinenko   }
120563c9d2b1SAlex Zinenko   printer.decreaseIndent();
120663c9d2b1SAlex Zinenko   printer.decreaseIndent();
120763c9d2b1SAlex Zinenko }
120863c9d2b1SAlex Zinenko 
120963c9d2b1SAlex Zinenko LogicalResult transform::ForeachMatchOp::verify() {
121063c9d2b1SAlex Zinenko   if (getMatchers().size() != getActions().size())
121163c9d2b1SAlex Zinenko     return emitOpError() << "expected the same number of matchers and actions";
121263c9d2b1SAlex Zinenko   if (getMatchers().empty())
121363c9d2b1SAlex Zinenko     return emitOpError() << "expected at least one match/action pair";
121463c9d2b1SAlex Zinenko 
121563c9d2b1SAlex Zinenko   llvm::SmallPtrSet<Attribute, 8> matcherNames;
121663c9d2b1SAlex Zinenko   for (Attribute name : getMatchers()) {
121763c9d2b1SAlex Zinenko     if (matcherNames.insert(name).second)
121863c9d2b1SAlex Zinenko       continue;
121963c9d2b1SAlex Zinenko     emitWarning() << "matcher " << name
122063c9d2b1SAlex Zinenko                   << " is used more than once, only the first match will apply";
122163c9d2b1SAlex Zinenko   }
122263c9d2b1SAlex Zinenko 
122363c9d2b1SAlex Zinenko   return success();
122463c9d2b1SAlex Zinenko }
122563c9d2b1SAlex Zinenko 
122663c9d2b1SAlex Zinenko /// Checks that the attributes of the function-like operation have correct
122763c9d2b1SAlex Zinenko /// consumption effect annotations. If `alsoVerifyInternal`, checks for
122863c9d2b1SAlex Zinenko /// annotations being present even if they can be inferred from the body.
122963c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure
1230135e5bf8SAlex Zinenko verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
123163c9d2b1SAlex Zinenko                                      bool alsoVerifyInternal = false) {
123263c9d2b1SAlex Zinenko   auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
123363c9d2b1SAlex Zinenko   llvm::SmallDenseSet<unsigned> consumedArguments;
123463c9d2b1SAlex Zinenko   if (!op.isExternal()) {
123563c9d2b1SAlex Zinenko     transform::getConsumedBlockArguments(op.getFunctionBody().front(),
123663c9d2b1SAlex Zinenko                                          consumedArguments);
123763c9d2b1SAlex Zinenko   }
123863c9d2b1SAlex Zinenko   for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
123963c9d2b1SAlex Zinenko     bool isConsumed =
124063c9d2b1SAlex Zinenko         op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
124163c9d2b1SAlex Zinenko         nullptr;
124263c9d2b1SAlex Zinenko     bool isReadOnly =
124363c9d2b1SAlex Zinenko         op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
124463c9d2b1SAlex Zinenko         nullptr;
124563c9d2b1SAlex Zinenko     if (isConsumed && isReadOnly) {
124663c9d2b1SAlex Zinenko       return transformOp.emitSilenceableError()
124763c9d2b1SAlex Zinenko              << "argument #" << i << " cannot be both readonly and consumed";
124863c9d2b1SAlex Zinenko     }
124963c9d2b1SAlex Zinenko     if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
125063c9d2b1SAlex Zinenko       return transformOp.emitSilenceableError()
125163c9d2b1SAlex Zinenko              << "must provide consumed/readonly status for arguments of "
125263c9d2b1SAlex Zinenko                 "external or called ops";
125363c9d2b1SAlex Zinenko     }
125463c9d2b1SAlex Zinenko     if (op.isExternal())
125563c9d2b1SAlex Zinenko       continue;
125663c9d2b1SAlex Zinenko 
125763c9d2b1SAlex Zinenko     if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
125863c9d2b1SAlex Zinenko       return transformOp.emitSilenceableError()
125963c9d2b1SAlex Zinenko              << "argument #" << i
126063c9d2b1SAlex Zinenko              << " is consumed in the body but is not marked as such";
126163c9d2b1SAlex Zinenko     }
1262135e5bf8SAlex Zinenko     if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1263135e5bf8SAlex Zinenko       // Cannot use op.emitWarning() here as it would attempt to verify the op
1264135e5bf8SAlex Zinenko       // before printing, resulting in infinite recursion.
1265135e5bf8SAlex Zinenko       emitWarning(op->getLoc())
1266135e5bf8SAlex Zinenko           << "op argument #" << i
126763c9d2b1SAlex Zinenko           << " is not consumed in the body but is marked as consumed";
126863c9d2b1SAlex Zinenko     }
126963c9d2b1SAlex Zinenko   }
127063c9d2b1SAlex Zinenko   return DiagnosedSilenceableFailure::success();
127163c9d2b1SAlex Zinenko }
127263c9d2b1SAlex Zinenko 
127363c9d2b1SAlex Zinenko LogicalResult transform::ForeachMatchOp::verifySymbolUses(
127463c9d2b1SAlex Zinenko     SymbolTableCollection &symbolTable) {
127563c9d2b1SAlex Zinenko   assert(getMatchers().size() == getActions().size());
127663c9d2b1SAlex Zinenko   auto consumedAttr =
127763c9d2b1SAlex Zinenko       StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
127863c9d2b1SAlex Zinenko   for (auto &&[matcher, action] :
127963c9d2b1SAlex Zinenko        llvm::zip_equal(getMatchers(), getActions())) {
1280e4b04b39SOleksandr "Alex" Zinenko     // Presence and typing.
128163c9d2b1SAlex Zinenko     auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
128263c9d2b1SAlex Zinenko         symbolTable.lookupNearestSymbolFrom(getOperation(),
128363c9d2b1SAlex Zinenko                                             cast<SymbolRefAttr>(matcher)));
128463c9d2b1SAlex Zinenko     auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
128563c9d2b1SAlex Zinenko         symbolTable.lookupNearestSymbolFrom(getOperation(),
128663c9d2b1SAlex Zinenko                                             cast<SymbolRefAttr>(action)));
128763c9d2b1SAlex Zinenko     if (!matcherSymbol ||
128863c9d2b1SAlex Zinenko         !isa<TransformOpInterface>(matcherSymbol.getOperation()))
128963c9d2b1SAlex Zinenko       return emitError() << "unresolved matcher symbol " << matcher;
129063c9d2b1SAlex Zinenko     if (!actionSymbol ||
129163c9d2b1SAlex Zinenko         !isa<TransformOpInterface>(actionSymbol.getOperation()))
129263c9d2b1SAlex Zinenko       return emitError() << "unresolved action symbol " << action;
129363c9d2b1SAlex Zinenko 
129463c9d2b1SAlex Zinenko     if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1295135e5bf8SAlex Zinenko                                                     /*emitWarnings=*/false,
129663c9d2b1SAlex Zinenko                                                     /*alsoVerifyInternal=*/true)
129763c9d2b1SAlex Zinenko                    .checkAndReport())) {
129863c9d2b1SAlex Zinenko       return failure();
129963c9d2b1SAlex Zinenko     }
130063c9d2b1SAlex Zinenko     if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1301135e5bf8SAlex Zinenko                                                     /*emitWarnings=*/false,
130263c9d2b1SAlex Zinenko                                                     /*alsoVerifyInternal=*/true)
130363c9d2b1SAlex Zinenko                    .checkAndReport())) {
130463c9d2b1SAlex Zinenko       return failure();
130563c9d2b1SAlex Zinenko     }
130663c9d2b1SAlex Zinenko 
1307e4b04b39SOleksandr "Alex" Zinenko     // Input -> matcher forwarding.
1308e4b04b39SOleksandr "Alex" Zinenko     TypeRange operandTypes = getOperandTypes();
1309e4b04b39SOleksandr "Alex" Zinenko     TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310e4b04b39SOleksandr "Alex" Zinenko     if (operandTypes.size() != matcherArguments.size()) {
1311e4b04b39SOleksandr "Alex" Zinenko       InFlightDiagnostic diag =
1312e4b04b39SOleksandr "Alex" Zinenko           emitError() << "the number of operands (" << operandTypes.size()
1313e4b04b39SOleksandr "Alex" Zinenko                       << ") doesn't match the number of matcher arguments ("
1314e4b04b39SOleksandr "Alex" Zinenko                       << matcherArguments.size() << ") for " << matcher;
1315e4b04b39SOleksandr "Alex" Zinenko       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1316e4b04b39SOleksandr "Alex" Zinenko       return diag;
1317e4b04b39SOleksandr "Alex" Zinenko     }
1318e4b04b39SOleksandr "Alex" Zinenko     for (auto &&[i, operand, argument] :
1319e4b04b39SOleksandr "Alex" Zinenko          llvm::enumerate(operandTypes, matcherArguments)) {
1320e4b04b39SOleksandr "Alex" Zinenko       if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1321e4b04b39SOleksandr "Alex" Zinenko         InFlightDiagnostic diag =
1322e4b04b39SOleksandr "Alex" Zinenko             emitOpError()
1323e4b04b39SOleksandr "Alex" Zinenko             << "does not expect matcher symbol to consume its operand #" << i;
1324e4b04b39SOleksandr "Alex" Zinenko         diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1325e4b04b39SOleksandr "Alex" Zinenko         return diag;
1326e4b04b39SOleksandr "Alex" Zinenko       }
1327e4b04b39SOleksandr "Alex" Zinenko 
1328e4b04b39SOleksandr "Alex" Zinenko       if (implementSameTransformInterface(operand, argument))
1329e4b04b39SOleksandr "Alex" Zinenko         continue;
1330e4b04b39SOleksandr "Alex" Zinenko 
1331e4b04b39SOleksandr "Alex" Zinenko       InFlightDiagnostic diag =
1332e4b04b39SOleksandr "Alex" Zinenko           emitError()
1333e4b04b39SOleksandr "Alex" Zinenko           << "mismatching type interfaces for operand and matcher argument #"
1334e4b04b39SOleksandr "Alex" Zinenko           << i << " of matcher " << matcher;
1335e4b04b39SOleksandr "Alex" Zinenko       diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1336e4b04b39SOleksandr "Alex" Zinenko       return diag;
1337e4b04b39SOleksandr "Alex" Zinenko     }
1338e4b04b39SOleksandr "Alex" Zinenko 
1339e4b04b39SOleksandr "Alex" Zinenko     // Matcher -> action forwarding.
1340e4b04b39SOleksandr "Alex" Zinenko     TypeRange matcherResults = matcherSymbol.getResultTypes();
1341e4b04b39SOleksandr "Alex" Zinenko     TypeRange actionArguments = actionSymbol.getArgumentTypes();
134263c9d2b1SAlex Zinenko     if (matcherResults.size() != actionArguments.size()) {
134363c9d2b1SAlex Zinenko       return emitError() << "mismatching number of matcher results and "
134463c9d2b1SAlex Zinenko                             "action arguments between "
134563c9d2b1SAlex Zinenko                          << matcher << " (" << matcherResults.size() << ") and "
134663c9d2b1SAlex Zinenko                          << action << " (" << actionArguments.size() << ")";
134763c9d2b1SAlex Zinenko     }
134863c9d2b1SAlex Zinenko     for (auto &&[i, matcherType, actionType] :
134963c9d2b1SAlex Zinenko          llvm::enumerate(matcherResults, actionArguments)) {
135063c9d2b1SAlex Zinenko       if (implementSameTransformInterface(matcherType, actionType))
135163c9d2b1SAlex Zinenko         continue;
135263c9d2b1SAlex Zinenko 
135363c9d2b1SAlex Zinenko       return emitError() << "mismatching type interfaces for matcher result "
135463c9d2b1SAlex Zinenko                             "and action argument #"
1355e4b04b39SOleksandr "Alex" Zinenko                          << i << "of matcher " << matcher << " and action "
1356e4b04b39SOleksandr "Alex" Zinenko                          << action;
135763c9d2b1SAlex Zinenko     }
135863c9d2b1SAlex Zinenko 
1359e4b04b39SOleksandr "Alex" Zinenko     // Action -> result forwarding.
1360e4b04b39SOleksandr "Alex" Zinenko     TypeRange actionResults = actionSymbol.getResultTypes();
1361e4b04b39SOleksandr "Alex" Zinenko     auto resultTypes = TypeRange(getResultTypes()).drop_front();
1362e4b04b39SOleksandr "Alex" Zinenko     if (actionResults.size() != resultTypes.size()) {
136363c9d2b1SAlex Zinenko       InFlightDiagnostic diag =
1364e4b04b39SOleksandr "Alex" Zinenko           emitError() << "the number of action results ("
1365e4b04b39SOleksandr "Alex" Zinenko                       << actionResults.size() << ") for " << action
1366e4b04b39SOleksandr "Alex" Zinenko                       << " doesn't match the number of extra op results ("
1367e4b04b39SOleksandr "Alex" Zinenko                       << resultTypes.size() << ")";
136863c9d2b1SAlex Zinenko       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
136963c9d2b1SAlex Zinenko       return diag;
137063c9d2b1SAlex Zinenko     }
1371e4b04b39SOleksandr "Alex" Zinenko     for (auto &&[i, resultType, actionType] :
1372e4b04b39SOleksandr "Alex" Zinenko          llvm::enumerate(resultTypes, actionResults)) {
1373e4b04b39SOleksandr "Alex" Zinenko       if (implementSameTransformInterface(resultType, actionType))
1374e4b04b39SOleksandr "Alex" Zinenko         continue;
137563c9d2b1SAlex Zinenko 
137663c9d2b1SAlex Zinenko       InFlightDiagnostic diag =
1377e4b04b39SOleksandr "Alex" Zinenko           emitError() << "mismatching type interfaces for action result #" << i
1378e4b04b39SOleksandr "Alex" Zinenko                       << " of action " << action << " and op result";
1379e4b04b39SOleksandr "Alex" Zinenko       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
138063c9d2b1SAlex Zinenko       return diag;
138163c9d2b1SAlex Zinenko     }
138263c9d2b1SAlex Zinenko   }
138363c9d2b1SAlex Zinenko   return success();
138463c9d2b1SAlex Zinenko }
138563c9d2b1SAlex Zinenko 
138663c9d2b1SAlex Zinenko //===----------------------------------------------------------------------===//
1387bba85ebdSAlex Zinenko // ForeachOp
1388bba85ebdSAlex Zinenko //===----------------------------------------------------------------------===//
1389bba85ebdSAlex Zinenko 
1390bba85ebdSAlex Zinenko DiagnosedSilenceableFailure
1391c63d2b2cSMatthias Springer transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1392c63d2b2cSMatthias Springer                             transform::TransformResults &results,
1393bffec215SMatthias Springer                             transform::TransformState &state) {
1394d462bf68SRolf Morel   // We store the payloads before executing the body as ops may be removed from
1395d462bf68SRolf Morel   // the mapping by the TrackingRewriter while iteration is in progress.
1396d462bf68SRolf Morel   SmallVector<SmallVector<MappedValue>> payloads;
1397d462bf68SRolf Morel   detail::prepareValueMappings(payloads, getTargets(), state);
1398d462bf68SRolf Morel   size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399d1ca1d01SGuillermo Callaghan   bool withZipShortest = getWithZipShortest();
1400a9efcbf4Smuneebkhan85 
1401a9efcbf4Smuneebkhan85   // In case of `zip_shortest`, set the number of iterations to the
1402a9efcbf4Smuneebkhan85   // smallest payload in the targets.
1403d1ca1d01SGuillermo Callaghan   if (withZipShortest) {
1404a9efcbf4Smuneebkhan85     numIterations =
1405a9efcbf4Smuneebkhan85         llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1406a9efcbf4Smuneebkhan85                                         const SmallVector<MappedValue> &B) {
1407a9efcbf4Smuneebkhan85           return A.size() < B.size();
1408a9efcbf4Smuneebkhan85         })->size();
1409a9efcbf4Smuneebkhan85 
1410a9efcbf4Smuneebkhan85     for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411a9efcbf4Smuneebkhan85       payloads[argIdx].resize(numIterations);
1412a9efcbf4Smuneebkhan85   }
1413d462bf68SRolf Morel 
1414d462bf68SRolf Morel   // As we will be "zipping" over them, check all payloads have the same size.
1415a9efcbf4Smuneebkhan85   // `zip_shortest` adjusts all payloads to the same size, so skip this check
1416a9efcbf4Smuneebkhan85   // when true.
1417d1ca1d01SGuillermo Callaghan   for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1418a9efcbf4Smuneebkhan85        argIdx++) {
1419d462bf68SRolf Morel     if (payloads[argIdx].size() != numIterations) {
1420d462bf68SRolf Morel       return emitSilenceableError()
1421d462bf68SRolf Morel              << "prior targets' payload size (" << numIterations
1422d462bf68SRolf Morel              << ") differs from payload size (" << payloads[argIdx].size()
1423d462bf68SRolf Morel              << ") of target " << getTargets()[argIdx];
1424d462bf68SRolf Morel     }
1425d462bf68SRolf Morel   }
1426d462bf68SRolf Morel 
1427d462bf68SRolf Morel   // Start iterating, indexing into payloads to obtain the right arguments to
1428d462bf68SRolf Morel   // call the body with - each slice of payloads at the same argument index
1429d462bf68SRolf Morel   // corresponding to a tuple to use as the body's block arguments.
1430d462bf68SRolf Morel   ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1431d462bf68SRolf Morel   SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1432d462bf68SRolf Morel   for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433bffec215SMatthias Springer     auto scope = state.make_region_scope(getBody());
1434d462bf68SRolf Morel     // Set up arguments to the region's block.
1435d462bf68SRolf Morel     for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1436d462bf68SRolf Morel       MappedValue argument = payloads[argIdx][iterIdx];
1437d462bf68SRolf Morel       // Note that each blockArg's handle gets associated with just a single
1438d462bf68SRolf Morel       // element from the corresponding target's payload.
1439d462bf68SRolf Morel       if (failed(state.mapBlockArgument(blockArg, {argument})))
1440bba85ebdSAlex Zinenko         return DiagnosedSilenceableFailure::definiteFailure();
1441d462bf68SRolf Morel     }
1442bffec215SMatthias Springer 
1443c1e6caacSMatthias Springer     // Execute loop body.
1444bffec215SMatthias Springer     for (Operation &transform : getBody().front().without_terminator()) {
1445bffec215SMatthias Springer       DiagnosedSilenceableFailure result = state.applyTransform(
1446d462bf68SRolf Morel           llvm::cast<transform::TransformOpInterface>(transform));
1447bffec215SMatthias Springer       if (!result.succeeded())
1448bffec215SMatthias Springer         return result;
1449bffec215SMatthias Springer     }
1450c1e6caacSMatthias Springer 
1451d462bf68SRolf Morel     // Append yielded payloads to corresponding results from prior iterations.
1452d462bf68SRolf Morel     OperandRange yieldOperands = getYieldOp().getOperands();
1453d462bf68SRolf Morel     for (auto &&[result, yieldOperand, resTuple] :
1454d462bf68SRolf Morel          llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1455d462bf68SRolf Morel       // NB: each iteration we add any number of ops/vals/params to a result.
1456d462bf68SRolf Morel       if (isa<TransformHandleTypeInterface>(result.getType()))
1457d462bf68SRolf Morel         llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458d462bf68SRolf Morel       else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459d462bf68SRolf Morel         llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460d462bf68SRolf Morel       else if (isa<TransformParamTypeInterface>(result.getType()))
1461d462bf68SRolf Morel         llvm::append_range(resTuple, state.getParams(yieldOperand));
1462d462bf68SRolf Morel       else
1463d462bf68SRolf Morel         assert(false && "unhandled handle type");
1464c1e6caacSMatthias Springer   }
1465c1e6caacSMatthias Springer 
1466d462bf68SRolf Morel   // Associate the accumulated result payloads to the op's actual results.
1467d462bf68SRolf Morel   for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1468d462bf68SRolf Morel     results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1469c1e6caacSMatthias Springer 
1470bffec215SMatthias Springer   return DiagnosedSilenceableFailure::success();
1471bffec215SMatthias Springer }
1472bffec215SMatthias Springer 
1473bffec215SMatthias Springer void transform::ForeachOp::getEffects(
1474bffec215SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1475d462bf68SRolf Morel   // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1476d462bf68SRolf Morel   // arity errors, this method might get called before/in absence of `verify()`.
1477d462bf68SRolf Morel   for (auto &&[target, blockArg] :
14782c1ae801Sdonald chen        llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1479d462bf68SRolf Morel     BlockArgument blockArgument = blockArg;
1480bffec215SMatthias Springer     if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1481d462bf68SRolf Morel           return isHandleConsumed(blockArgument,
1482d462bf68SRolf Morel                                   cast<TransformOpInterface>(&op));
1483bffec215SMatthias Springer         })) {
1484d462bf68SRolf Morel       consumesHandle(target, effects);
1485bffec215SMatthias Springer     } else {
1486d462bf68SRolf Morel       onlyReadsHandle(target, effects);
1487d462bf68SRolf Morel     }
1488bffec215SMatthias Springer   }
1489c1e6caacSMatthias Springer 
14904f63252dSMatthias Springer   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
14914f63252dSMatthias Springer         return doesModifyPayload(cast<TransformOpInterface>(&op));
14924f63252dSMatthias Springer       })) {
14934f63252dSMatthias Springer     modifiesPayload(effects);
14944f63252dSMatthias Springer   } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
14954f63252dSMatthias Springer                return doesReadPayload(cast<TransformOpInterface>(&op));
14964f63252dSMatthias Springer              })) {
14974f63252dSMatthias Springer     onlyReadsPayload(effects);
14984f63252dSMatthias Springer   }
14994f63252dSMatthias Springer 
15002c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
1501bffec215SMatthias Springer }
1502bffec215SMatthias Springer 
1503bffec215SMatthias Springer void transform::ForeachOp::getSuccessorRegions(
15044dd744acSMarkus Böck     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1505bffec215SMatthias Springer   Region *bodyRegion = &getBody();
15064dd744acSMarkus Böck   if (point.isParent()) {
1507bffec215SMatthias Springer     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1508bffec215SMatthias Springer     return;
1509bffec215SMatthias Springer   }
1510bffec215SMatthias Springer 
1511bffec215SMatthias Springer   // Branch back to the region or the parent.
15124dd744acSMarkus Böck   assert(point == getBody() && "unexpected region index");
1513bffec215SMatthias Springer   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1514bffec215SMatthias Springer   regions.emplace_back();
1515bffec215SMatthias Springer }
1516bffec215SMatthias Springer 
1517bffec215SMatthias Springer OperandRange
15184dd744acSMarkus Böck transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1519d462bf68SRolf Morel   // Each block argument handle is mapped to a subset (one op to be precise)
1520d462bf68SRolf Morel   // of the payload of the corresponding `targets` operand of ForeachOp.
15214dd744acSMarkus Böck   assert(point == getBody() && "unexpected region index");
1522bffec215SMatthias Springer   return getOperation()->getOperands();
1523bffec215SMatthias Springer }
1524bffec215SMatthias Springer 
1525c1e6caacSMatthias Springer transform::YieldOp transform::ForeachOp::getYieldOp() {
1526c1e6caacSMatthias Springer   return cast<transform::YieldOp>(getBody().front().getTerminator());
1527c1e6caacSMatthias Springer }
1528c1e6caacSMatthias Springer 
1529c1e6caacSMatthias Springer LogicalResult transform::ForeachOp::verify() {
1530d462bf68SRolf Morel   for (auto [targetOpt, bodyArgOpt] :
1531d462bf68SRolf Morel        llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532d462bf68SRolf Morel     if (!targetOpt || !bodyArgOpt)
1533d462bf68SRolf Morel       return emitOpError() << "expects the same number of targets as the body "
1534d462bf68SRolf Morel                               "has block arguments";
1535d462bf68SRolf Morel     if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1536d462bf68SRolf Morel       return emitOpError(
1537d462bf68SRolf Morel           "expects co-indexed targets and the body's "
1538d462bf68SRolf Morel           "block arguments to have the same op/value/param type");
1539d462bf68SRolf Morel   }
1540d462bf68SRolf Morel 
1541d462bf68SRolf Morel   for (auto [resultOpt, yieldOperandOpt] :
1542d462bf68SRolf Morel        llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543d462bf68SRolf Morel     if (!resultOpt || !yieldOperandOpt)
1544c1e6caacSMatthias Springer       return emitOpError() << "expects the same number of results as the "
1545d462bf68SRolf Morel                               "yield terminator has operands";
1546d462bf68SRolf Morel     if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547d462bf68SRolf Morel       return emitOpError("expects co-indexed results and yield "
1548d462bf68SRolf Morel                          "operands to have the same op/value/param type");
1549d462bf68SRolf Morel   }
1550d462bf68SRolf Morel 
1551c1e6caacSMatthias Springer   return success();
1552c1e6caacSMatthias Springer }
1553c1e6caacSMatthias Springer 
1554bffec215SMatthias Springer //===----------------------------------------------------------------------===//
15554106557aSMatthias Springer // GetParentOp
1556cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===//
1557cc6c1592SAlex Zinenko 
15584106557aSMatthias Springer DiagnosedSilenceableFailure
15594106557aSMatthias Springer transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
15604106557aSMatthias Springer                               transform::TransformResults &results,
15614106557aSMatthias Springer                               transform::TransformState &state) {
15624106557aSMatthias Springer   SmallVector<Operation *> parents;
15634106557aSMatthias Springer   DenseSet<Operation *> resultSet;
1564cc6c1592SAlex Zinenko   for (Operation *target : state.getPayloadOps(getTarget())) {
156504736c7fSMatthias Springer     Operation *parent = target;
156604736c7fSMatthias Springer     for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
156704736c7fSMatthias Springer       parent = parent->getParentOp();
156868033aaaSIngo Müller       while (parent) {
15694106557aSMatthias Springer         bool checkIsolatedFromAbove =
15704106557aSMatthias Springer             !getIsolatedFromAbove() ||
15714106557aSMatthias Springer             parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
15724106557aSMatthias Springer         bool checkOpName = !getOpName().has_value() ||
15734106557aSMatthias Springer                            parent->getName().getStringRef() == *getOpName();
15744106557aSMatthias Springer         if (checkIsolatedFromAbove && checkOpName)
15754106557aSMatthias Springer           break;
157668033aaaSIngo Müller         parent = parent->getParentOp();
157768033aaaSIngo Müller       }
1578cc6c1592SAlex Zinenko       if (!parent) {
157998341df0SNicolas Vasilache         if (getAllowEmptyResults()) {
158098341df0SNicolas Vasilache           results.set(llvm::cast<OpResult>(getResult()), parents);
158198341df0SNicolas Vasilache           return DiagnosedSilenceableFailure::success();
158298341df0SNicolas Vasilache         }
15831d45282aSAlex Zinenko         DiagnosedSilenceableFailure diag =
15841d45282aSAlex Zinenko             emitSilenceableError()
15854106557aSMatthias Springer             << "could not find a parent op that matches all requirements";
1586cc6c1592SAlex Zinenko         diag.attachNote(target->getLoc()) << "target op";
1587cc6c1592SAlex Zinenko         return diag;
1588cc6c1592SAlex Zinenko       }
158904736c7fSMatthias Springer     }
15904106557aSMatthias Springer     if (getDeduplicate()) {
159167e7f05aSKazu Hirata       if (resultSet.insert(parent).second)
15924106557aSMatthias Springer         parents.push_back(parent);
15934106557aSMatthias Springer     } else {
15944106557aSMatthias Springer       parents.push_back(parent);
15954106557aSMatthias Springer     }
15964106557aSMatthias Springer   }
15974106557aSMatthias Springer   results.set(llvm::cast<OpResult>(getResult()), parents);
15981d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
1599cc6c1592SAlex Zinenko }
1600cc6c1592SAlex Zinenko 
1601cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===//
16024adf89fcSNicolas Vasilache // GetConsumersOfResult
16034adf89fcSNicolas Vasilache //===----------------------------------------------------------------------===//
16044adf89fcSNicolas Vasilache 
16054adf89fcSNicolas Vasilache DiagnosedSilenceableFailure
1606c63d2b2cSMatthias Springer transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1607c63d2b2cSMatthias Springer                                        transform::TransformResults &results,
16084adf89fcSNicolas Vasilache                                        transform::TransformState &state) {
16094adf89fcSNicolas Vasilache   int64_t resultNumber = getResultNumber();
16100e37ef08SMatthias Springer   auto payloadOps = state.getPayloadOps(getTarget());
16110e37ef08SMatthias Springer   if (std::empty(payloadOps)) {
16120e37ef08SMatthias Springer     results.set(cast<OpResult>(getResult()), {});
16134adf89fcSNicolas Vasilache     return DiagnosedSilenceableFailure::success();
16144adf89fcSNicolas Vasilache   }
16150e37ef08SMatthias Springer   if (!llvm::hasSingleElement(payloadOps))
16164adf89fcSNicolas Vasilache     return emitDefiniteFailure()
16174adf89fcSNicolas Vasilache            << "handle must be mapped to exactly one payload op";
16184adf89fcSNicolas Vasilache 
16190e37ef08SMatthias Springer   Operation *target = *payloadOps.begin();
16204adf89fcSNicolas Vasilache   if (target->getNumResults() <= resultNumber)
16214adf89fcSNicolas Vasilache     return emitDefiniteFailure() << "result number overflow";
1622c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getResult()),
16234adf89fcSNicolas Vasilache               llvm::to_vector(target->getResult(resultNumber).getUsers()));
16244adf89fcSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
16254adf89fcSNicolas Vasilache }
16264adf89fcSNicolas Vasilache 
16274adf89fcSNicolas Vasilache //===----------------------------------------------------------------------===//
16284cf936d0SMatthias Springer // GetDefiningOp
16294cf936d0SMatthias Springer //===----------------------------------------------------------------------===//
16304cf936d0SMatthias Springer 
16314cf936d0SMatthias Springer DiagnosedSilenceableFailure
1632c63d2b2cSMatthias Springer transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1633c63d2b2cSMatthias Springer                                 transform::TransformResults &results,
16344cf936d0SMatthias Springer                                 transform::TransformState &state) {
16354cf936d0SMatthias Springer   SmallVector<Operation *> definingOps;
16364cf936d0SMatthias Springer   for (Value v : state.getPayloadValues(getTarget())) {
1637c1fa60b4STres Popp     if (llvm::isa<BlockArgument>(v)) {
16384cf936d0SMatthias Springer       DiagnosedSilenceableFailure diag =
16394cf936d0SMatthias Springer           emitSilenceableError() << "cannot get defining op of block argument";
16404cf936d0SMatthias Springer       diag.attachNote(v.getLoc()) << "target value";
16414cf936d0SMatthias Springer       return diag;
16424cf936d0SMatthias Springer     }
16434cf936d0SMatthias Springer     definingOps.push_back(v.getDefiningOp());
16444cf936d0SMatthias Springer   }
1645c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getResult()), definingOps);
16464cf936d0SMatthias Springer   return DiagnosedSilenceableFailure::success();
16474cf936d0SMatthias Springer }
16484cf936d0SMatthias Springer 
16494cf936d0SMatthias Springer //===----------------------------------------------------------------------===//
1650ecd9dc04SNicolas Vasilache // GetProducerOfOperand
1651ecd9dc04SNicolas Vasilache //===----------------------------------------------------------------------===//
1652ecd9dc04SNicolas Vasilache 
1653ecd9dc04SNicolas Vasilache DiagnosedSilenceableFailure
1654c63d2b2cSMatthias Springer transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1655c63d2b2cSMatthias Springer                                        transform::TransformResults &results,
1656ecd9dc04SNicolas Vasilache                                        transform::TransformState &state) {
1657ecd9dc04SNicolas Vasilache   int64_t operandNumber = getOperandNumber();
1658ecd9dc04SNicolas Vasilache   SmallVector<Operation *> producers;
1659ecd9dc04SNicolas Vasilache   for (Operation *target : state.getPayloadOps(getTarget())) {
1660ecd9dc04SNicolas Vasilache     Operation *producer =
1661ecd9dc04SNicolas Vasilache         target->getNumOperands() <= operandNumber
1662ecd9dc04SNicolas Vasilache             ? nullptr
1663ecd9dc04SNicolas Vasilache             : target->getOperand(operandNumber).getDefiningOp();
1664ecd9dc04SNicolas Vasilache     if (!producer) {
1665ecd9dc04SNicolas Vasilache       DiagnosedSilenceableFailure diag =
1666ecd9dc04SNicolas Vasilache           emitSilenceableError()
1667ecd9dc04SNicolas Vasilache           << "could not find a producer for operand number: " << operandNumber
1668ecd9dc04SNicolas Vasilache           << " of " << *target;
1669ecd9dc04SNicolas Vasilache       diag.attachNote(target->getLoc()) << "target op";
1670ecd9dc04SNicolas Vasilache       return diag;
1671ecd9dc04SNicolas Vasilache     }
1672ecd9dc04SNicolas Vasilache     producers.push_back(producer);
1673ecd9dc04SNicolas Vasilache   }
1674c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getResult()), producers);
1675ecd9dc04SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
1676ecd9dc04SNicolas Vasilache }
1677ecd9dc04SNicolas Vasilache 
1678ecd9dc04SNicolas Vasilache //===----------------------------------------------------------------------===//
16795caab8bbSQuinn Dawkins // GetOperandOp
16805caab8bbSQuinn Dawkins //===----------------------------------------------------------------------===//
16815caab8bbSQuinn Dawkins 
16825caab8bbSQuinn Dawkins DiagnosedSilenceableFailure
16835caab8bbSQuinn Dawkins transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
16845caab8bbSQuinn Dawkins                                transform::TransformResults &results,
16855caab8bbSQuinn Dawkins                                transform::TransformState &state) {
16865caab8bbSQuinn Dawkins   SmallVector<Value> operands;
16875caab8bbSQuinn Dawkins   for (Operation *target : state.getPayloadOps(getTarget())) {
16885caab8bbSQuinn Dawkins     SmallVector<int64_t> operandPositions;
16895caab8bbSQuinn Dawkins     DiagnosedSilenceableFailure diag = expandTargetSpecification(
16905caab8bbSQuinn Dawkins         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
16915caab8bbSQuinn Dawkins         target->getNumOperands(), operandPositions);
16925caab8bbSQuinn Dawkins     if (diag.isSilenceableFailure()) {
16935caab8bbSQuinn Dawkins       diag.attachNote(target->getLoc())
16945caab8bbSQuinn Dawkins           << "while considering positions of this payload operation";
16955caab8bbSQuinn Dawkins       return diag;
16965caab8bbSQuinn Dawkins     }
16975caab8bbSQuinn Dawkins     llvm::append_range(operands,
16985caab8bbSQuinn Dawkins                        llvm::map_range(operandPositions, [&](int64_t pos) {
16995caab8bbSQuinn Dawkins                          return target->getOperand(pos);
17005caab8bbSQuinn Dawkins                        }));
17015caab8bbSQuinn Dawkins   }
17025caab8bbSQuinn Dawkins   results.setValues(cast<OpResult>(getResult()), operands);
17035caab8bbSQuinn Dawkins   return DiagnosedSilenceableFailure::success();
17045caab8bbSQuinn Dawkins }
17055caab8bbSQuinn Dawkins 
17065caab8bbSQuinn Dawkins LogicalResult transform::GetOperandOp::verify() {
17075caab8bbSQuinn Dawkins   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
17085caab8bbSQuinn Dawkins                                     getIsInverted(), getIsAll());
17095caab8bbSQuinn Dawkins }
17105caab8bbSQuinn Dawkins 
17115caab8bbSQuinn Dawkins //===----------------------------------------------------------------------===//
17123ef062a4SMatthias Springer // GetResultOp
17133ef062a4SMatthias Springer //===----------------------------------------------------------------------===//
17143ef062a4SMatthias Springer 
17153ef062a4SMatthias Springer DiagnosedSilenceableFailure
1716c63d2b2cSMatthias Springer transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1717c63d2b2cSMatthias Springer                               transform::TransformResults &results,
17183ef062a4SMatthias Springer                               transform::TransformState &state) {
17193ef062a4SMatthias Springer   SmallVector<Value> opResults;
17203ef062a4SMatthias Springer   for (Operation *target : state.getPayloadOps(getTarget())) {
17215caab8bbSQuinn Dawkins     SmallVector<int64_t> resultPositions;
17225caab8bbSQuinn Dawkins     DiagnosedSilenceableFailure diag = expandTargetSpecification(
17235caab8bbSQuinn Dawkins         getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
17245caab8bbSQuinn Dawkins         target->getNumResults(), resultPositions);
17255caab8bbSQuinn Dawkins     if (diag.isSilenceableFailure()) {
17265caab8bbSQuinn Dawkins       diag.attachNote(target->getLoc())
17275caab8bbSQuinn Dawkins           << "while considering positions of this payload operation";
17283ef062a4SMatthias Springer       return diag;
17293ef062a4SMatthias Springer     }
17305caab8bbSQuinn Dawkins     llvm::append_range(opResults,
17315caab8bbSQuinn Dawkins                        llvm::map_range(resultPositions, [&](int64_t pos) {
17325caab8bbSQuinn Dawkins                          return target->getResult(pos);
17335caab8bbSQuinn Dawkins                        }));
17343ef062a4SMatthias Springer   }
17355caab8bbSQuinn Dawkins   results.setValues(cast<OpResult>(getResult()), opResults);
17363ef062a4SMatthias Springer   return DiagnosedSilenceableFailure::success();
17373ef062a4SMatthias Springer }
17383ef062a4SMatthias Springer 
17395caab8bbSQuinn Dawkins LogicalResult transform::GetResultOp::verify() {
17405caab8bbSQuinn Dawkins   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
17415caab8bbSQuinn Dawkins                                     getIsInverted(), getIsAll());
17425caab8bbSQuinn Dawkins }
17435caab8bbSQuinn Dawkins 
17443ef062a4SMatthias Springer //===----------------------------------------------------------------------===//
1745dd81c6b8SAlex Zinenko // GetTypeOp
1746dd81c6b8SAlex Zinenko //===----------------------------------------------------------------------===//
1747dd81c6b8SAlex Zinenko 
1748dd81c6b8SAlex Zinenko void transform::GetTypeOp::getEffects(
1749dd81c6b8SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
17502c1ae801Sdonald chen   onlyReadsHandle(getValueMutable(), effects);
17512c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
1752dd81c6b8SAlex Zinenko   onlyReadsPayload(effects);
1753dd81c6b8SAlex Zinenko }
1754dd81c6b8SAlex Zinenko 
1755dd81c6b8SAlex Zinenko DiagnosedSilenceableFailure
1756dd81c6b8SAlex Zinenko transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1757dd81c6b8SAlex Zinenko                             transform::TransformResults &results,
1758dd81c6b8SAlex Zinenko                             transform::TransformState &state) {
1759dd81c6b8SAlex Zinenko   SmallVector<Attribute> params;
1760085075a5SMatthias Springer   for (Value value : state.getPayloadValues(getValue())) {
1761dd81c6b8SAlex Zinenko     Type type = value.getType();
1762dd81c6b8SAlex Zinenko     if (getElemental()) {
1763dd81c6b8SAlex Zinenko       if (auto shaped = dyn_cast<ShapedType>(type)) {
1764dd81c6b8SAlex Zinenko         type = shaped.getElementType();
1765dd81c6b8SAlex Zinenko       }
1766dd81c6b8SAlex Zinenko     }
1767dd81c6b8SAlex Zinenko     params.push_back(TypeAttr::get(type));
1768dd81c6b8SAlex Zinenko   }
1769a5757c5bSChristian Sigg   results.setParams(cast<OpResult>(getResult()), params);
1770dd81c6b8SAlex Zinenko   return DiagnosedSilenceableFailure::success();
1771dd81c6b8SAlex Zinenko }
1772dd81c6b8SAlex Zinenko 
1773dd81c6b8SAlex Zinenko //===----------------------------------------------------------------------===//
1774fb409a28SAlex Zinenko // IncludeOp
1775fb409a28SAlex Zinenko //===----------------------------------------------------------------------===//
1776fb409a28SAlex Zinenko 
1777fb409a28SAlex Zinenko /// Applies the transform ops contained in `block`. Maps `results` to the same
1778fb409a28SAlex Zinenko /// values as the operands of the block terminator.
1779fb409a28SAlex Zinenko static DiagnosedSilenceableFailure
1780fb409a28SAlex Zinenko applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1781fb409a28SAlex Zinenko                    transform::TransformState &state,
1782fb409a28SAlex Zinenko                    transform::TransformResults &results) {
1783fb409a28SAlex Zinenko   // Apply the sequenced ops one by one.
1784fb409a28SAlex Zinenko   for (Operation &transform : block.without_terminator()) {
1785fb409a28SAlex Zinenko     DiagnosedSilenceableFailure result =
1786fb409a28SAlex Zinenko         state.applyTransform(cast<transform::TransformOpInterface>(transform));
1787fb409a28SAlex Zinenko     if (result.isDefiniteFailure())
1788fb409a28SAlex Zinenko       return result;
1789fb409a28SAlex Zinenko 
1790fb409a28SAlex Zinenko     if (result.isSilenceableFailure()) {
1791fb409a28SAlex Zinenko       if (mode == transform::FailurePropagationMode::Propagate) {
1792fb409a28SAlex Zinenko         // Propagate empty results in case of early exit.
1793fb409a28SAlex Zinenko         forwardEmptyOperands(&block, state, results);
1794fb409a28SAlex Zinenko         return result;
1795fb409a28SAlex Zinenko       }
1796fb409a28SAlex Zinenko       (void)result.silence();
1797fb409a28SAlex Zinenko     }
1798fb409a28SAlex Zinenko   }
1799fb409a28SAlex Zinenko 
1800fb409a28SAlex Zinenko   // Forward the operation mapping for values yielded from the sequence to the
1801fb409a28SAlex Zinenko   // values produced by the sequence op.
180263c9d2b1SAlex Zinenko   transform::detail::forwardTerminatorOperands(&block, state, results);
1803fb409a28SAlex Zinenko   return DiagnosedSilenceableFailure::success();
1804fb409a28SAlex Zinenko }
1805fb409a28SAlex Zinenko 
1806fb409a28SAlex Zinenko DiagnosedSilenceableFailure
1807c63d2b2cSMatthias Springer transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1808c63d2b2cSMatthias Springer                             transform::TransformResults &results,
1809fb409a28SAlex Zinenko                             transform::TransformState &state) {
1810fb409a28SAlex Zinenko   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1811fb409a28SAlex Zinenko       getOperation(), getTarget());
1812fb409a28SAlex Zinenko   assert(callee && "unverified reference to unknown symbol");
1813fb409a28SAlex Zinenko 
181492c69468SAlex Zinenko   if (callee.isExternal())
181592c69468SAlex Zinenko     return emitDefiniteFailure() << "unresolved external named sequence";
181692c69468SAlex Zinenko 
1817fb409a28SAlex Zinenko   // Map operands to block arguments.
1818fb409a28SAlex Zinenko   SmallVector<SmallVector<MappedValue>> mappings;
1819fb409a28SAlex Zinenko   detail::prepareValueMappings(mappings, getOperands(), state);
182022259281SMatthias Springer   auto scope = state.make_region_scope(callee.getBody());
1821fb409a28SAlex Zinenko   for (auto &&[arg, map] :
1822fb409a28SAlex Zinenko        llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1823fb409a28SAlex Zinenko     if (failed(state.mapBlockArgument(arg, map)))
1824fb409a28SAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
1825fb409a28SAlex Zinenko   }
1826fb409a28SAlex Zinenko 
1827fb409a28SAlex Zinenko   DiagnosedSilenceableFailure result = applySequenceBlock(
1828fb409a28SAlex Zinenko       callee.getBody().front(), getFailurePropagationMode(), state, results);
1829fb409a28SAlex Zinenko   mappings.clear();
1830fb409a28SAlex Zinenko   detail::prepareValueMappings(
1831fb409a28SAlex Zinenko       mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1832fb409a28SAlex Zinenko   for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1833fb409a28SAlex Zinenko     results.setMappedValues(result, mapping);
1834fb409a28SAlex Zinenko   return result;
1835fb409a28SAlex Zinenko }
1836fb409a28SAlex Zinenko 
1837fb409a28SAlex Zinenko static DiagnosedSilenceableFailure
1838135e5bf8SAlex Zinenko verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1839fb409a28SAlex Zinenko 
1840fb409a28SAlex Zinenko void transform::IncludeOp::getEffects(
1841fb409a28SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184241109341SAlex Zinenko   // Always mark as modifying the payload.
184341109341SAlex Zinenko   // TODO: a mechanism to annotate effects on payload. Even when all handles are
184441109341SAlex Zinenko   // only read, the payload may still be modified, so we currently stay on the
184541109341SAlex Zinenko   // conservative side and always indicate modification. This may prevent some
184641109341SAlex Zinenko   // code reordering.
184741109341SAlex Zinenko   modifiesPayload(effects);
184841109341SAlex Zinenko 
184941109341SAlex Zinenko   // Results are always produced.
18502c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
185141109341SAlex Zinenko 
185241109341SAlex Zinenko   // Adds default effects to operands and results. This will be added if
185341109341SAlex Zinenko   // preconditions fail so the trait verifier doesn't complain about missing
185441109341SAlex Zinenko   // effects and the real precondition failure is reported later on.
18552c1ae801Sdonald chen   auto defaultEffects = [&] {
18562c1ae801Sdonald chen     onlyReadsHandle(getOperation()->getOpOperands(), effects);
18572c1ae801Sdonald chen   };
185841109341SAlex Zinenko 
1859fb409a28SAlex Zinenko   // Bail if the callee is unknown. This may run as part of the verification
1860fb409a28SAlex Zinenko   // process before we verified the validity of the callee or of this op.
1861fb409a28SAlex Zinenko   auto target =
1862fb409a28SAlex Zinenko       getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1863fb409a28SAlex Zinenko   if (!target)
186441109341SAlex Zinenko     return defaultEffects();
1865fb409a28SAlex Zinenko   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1866fb409a28SAlex Zinenko       getOperation(), getTarget());
1867fb409a28SAlex Zinenko   if (!callee)
186841109341SAlex Zinenko     return defaultEffects();
1869fb409a28SAlex Zinenko   DiagnosedSilenceableFailure earlyVerifierResult =
1870135e5bf8SAlex Zinenko       verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1871fb409a28SAlex Zinenko   if (!earlyVerifierResult.succeeded()) {
1872fb409a28SAlex Zinenko     (void)earlyVerifierResult.silence();
187341109341SAlex Zinenko     return defaultEffects();
1874fb409a28SAlex Zinenko   }
1875fb409a28SAlex Zinenko 
187641109341SAlex Zinenko   for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
187741109341SAlex Zinenko     if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
18782c1ae801Sdonald chen       consumesHandle(getOperation()->getOpOperand(i), effects);
187941109341SAlex Zinenko     else
18802c1ae801Sdonald chen       onlyReadsHandle(getOperation()->getOpOperand(i), effects);
188141109341SAlex Zinenko   }
1882fb409a28SAlex Zinenko }
1883fb409a28SAlex Zinenko 
1884fb409a28SAlex Zinenko LogicalResult
1885fb409a28SAlex Zinenko transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1886fb409a28SAlex Zinenko   // Access through indirection and do additional checking because this may be
1887fb409a28SAlex Zinenko   // running before the main op verifier.
1888fb409a28SAlex Zinenko   auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1889fb409a28SAlex Zinenko   if (!targetAttr)
1890fb409a28SAlex Zinenko     return emitOpError() << "expects a 'target' symbol reference attribute";
1891fb409a28SAlex Zinenko 
1892fb409a28SAlex Zinenko   auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1893fb409a28SAlex Zinenko       *this, targetAttr);
1894fb409a28SAlex Zinenko   if (!target)
1895fb409a28SAlex Zinenko     return emitOpError() << "does not reference a named transform sequence";
1896fb409a28SAlex Zinenko 
1897fb409a28SAlex Zinenko   FunctionType fnType = target.getFunctionType();
1898fb409a28SAlex Zinenko   if (fnType.getNumInputs() != getNumOperands())
1899fb409a28SAlex Zinenko     return emitError("incorrect number of operands for callee");
1900fb409a28SAlex Zinenko 
1901fb409a28SAlex Zinenko   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1902fb409a28SAlex Zinenko     if (getOperand(i).getType() != fnType.getInput(i)) {
1903fb409a28SAlex Zinenko       return emitOpError("operand type mismatch: expected operand type ")
1904fb409a28SAlex Zinenko              << fnType.getInput(i) << ", but provided "
1905fb409a28SAlex Zinenko              << getOperand(i).getType() << " for operand number " << i;
1906fb409a28SAlex Zinenko     }
1907fb409a28SAlex Zinenko   }
1908fb409a28SAlex Zinenko 
1909fb409a28SAlex Zinenko   if (fnType.getNumResults() != getNumResults())
1910fb409a28SAlex Zinenko     return emitError("incorrect number of results for callee");
1911fb409a28SAlex Zinenko 
1912fb409a28SAlex Zinenko   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1913fb409a28SAlex Zinenko     Type resultType = getResult(i).getType();
1914fb409a28SAlex Zinenko     Type funcType = fnType.getResult(i);
191563c9d2b1SAlex Zinenko     if (!implementSameTransformInterface(resultType, funcType)) {
1916fb409a28SAlex Zinenko       return emitOpError() << "type of result #" << i
1917fb409a28SAlex Zinenko                            << " must implement the same transform dialect "
1918fb409a28SAlex Zinenko                               "interface as the corresponding callee result";
1919fb409a28SAlex Zinenko     }
1920fb409a28SAlex Zinenko   }
1921fb409a28SAlex Zinenko 
192263c9d2b1SAlex Zinenko   return verifyFunctionLikeConsumeAnnotations(
1923135e5bf8SAlex Zinenko              cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
192441109341SAlex Zinenko              /*alsoVerifyInternal=*/true)
192541109341SAlex Zinenko       .checkAndReport();
1926fb409a28SAlex Zinenko }
1927fb409a28SAlex Zinenko 
1928fb409a28SAlex Zinenko //===----------------------------------------------------------------------===//
192998341df0SNicolas Vasilache // MatchOperationEmptyOp
193098341df0SNicolas Vasilache //===----------------------------------------------------------------------===//
193198341df0SNicolas Vasilache 
193298341df0SNicolas Vasilache DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
193398341df0SNicolas Vasilache     ::std::optional<::mlir::Operation *> maybeCurrent,
193498341df0SNicolas Vasilache     transform::TransformResults &results, transform::TransformState &state) {
193598341df0SNicolas Vasilache   if (!maybeCurrent.has_value()) {
19368483d18bSNicolas Vasilache     DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
193798341df0SNicolas Vasilache     return DiagnosedSilenceableFailure::success();
193898341df0SNicolas Vasilache   }
19398483d18bSNicolas Vasilache   DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
194098341df0SNicolas Vasilache   return emitSilenceableError() << "operation is not empty";
194198341df0SNicolas Vasilache }
194298341df0SNicolas Vasilache 
194398341df0SNicolas Vasilache //===----------------------------------------------------------------------===//
19443fe7127dSAlex Zinenko // MatchOperationNameOp
19453fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
19463fe7127dSAlex Zinenko 
19473fe7127dSAlex Zinenko DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
19483fe7127dSAlex Zinenko     Operation *current, transform::TransformResults &results,
19493fe7127dSAlex Zinenko     transform::TransformState &state) {
19503fe7127dSAlex Zinenko   StringRef currentOpName = current->getName().getStringRef();
19513fe7127dSAlex Zinenko   for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
19523fe7127dSAlex Zinenko     if (acceptedAttr.getValue() == currentOpName)
19533fe7127dSAlex Zinenko       return DiagnosedSilenceableFailure::success();
19543fe7127dSAlex Zinenko   }
19553fe7127dSAlex Zinenko   return emitSilenceableError() << "wrong operation name";
19563fe7127dSAlex Zinenko }
19573fe7127dSAlex Zinenko 
19583fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
19593fe7127dSAlex Zinenko // MatchParamCmpIOp
19603fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
19613fe7127dSAlex Zinenko 
19623fe7127dSAlex Zinenko DiagnosedSilenceableFailure
1963c63d2b2cSMatthias Springer transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1964c63d2b2cSMatthias Springer                                    transform::TransformResults &results,
19653fe7127dSAlex Zinenko                                    transform::TransformState &state) {
196670ebc78eSMehdi Amini   auto signedAPIntAsString = [&](const APInt &value) {
19673fe7127dSAlex Zinenko     std::string str;
19683fe7127dSAlex Zinenko     llvm::raw_string_ostream os(str);
19693fe7127dSAlex Zinenko     value.print(os, /*isSigned=*/true);
1970884221edSJOE1994     return str;
19713fe7127dSAlex Zinenko   };
19723fe7127dSAlex Zinenko 
19733fe7127dSAlex Zinenko   ArrayRef<Attribute> params = state.getParams(getParam());
19743fe7127dSAlex Zinenko   ArrayRef<Attribute> references = state.getParams(getReference());
19753fe7127dSAlex Zinenko 
19763fe7127dSAlex Zinenko   if (params.size() != references.size()) {
19773fe7127dSAlex Zinenko     return emitSilenceableError()
19783fe7127dSAlex Zinenko            << "parameters have different payload lengths (" << params.size()
19793fe7127dSAlex Zinenko            << " vs " << references.size() << ")";
19803fe7127dSAlex Zinenko   }
19813fe7127dSAlex Zinenko 
19823fe7127dSAlex Zinenko   for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1983c1fa60b4STres Popp     auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1984c1fa60b4STres Popp     auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
19853fe7127dSAlex Zinenko     if (!intAttr || !refAttr) {
19863fe7127dSAlex Zinenko       return emitDefiniteFailure()
19873fe7127dSAlex Zinenko              << "non-integer parameter value not expected";
19883fe7127dSAlex Zinenko     }
19893fe7127dSAlex Zinenko     if (intAttr.getType() != refAttr.getType()) {
19903fe7127dSAlex Zinenko       return emitDefiniteFailure()
19913fe7127dSAlex Zinenko              << "mismatching integer attribute types in parameter #" << i;
19923fe7127dSAlex Zinenko     }
19933fe7127dSAlex Zinenko     APInt value = intAttr.getValue();
19943fe7127dSAlex Zinenko     APInt refValue = refAttr.getValue();
19953fe7127dSAlex Zinenko 
19963fe7127dSAlex Zinenko     // TODO: this copy will not be necessary in C++20.
19973fe7127dSAlex Zinenko     int64_t position = i;
19983fe7127dSAlex Zinenko     auto reportError = [&](StringRef direction) {
19993fe7127dSAlex Zinenko       DiagnosedSilenceableFailure diag =
20003fe7127dSAlex Zinenko           emitSilenceableError() << "expected parameter to be " << direction
20013fe7127dSAlex Zinenko                                  << " " << signedAPIntAsString(refValue)
20023fe7127dSAlex Zinenko                                  << ", got " << signedAPIntAsString(value);
20033fe7127dSAlex Zinenko       diag.attachNote(getParam().getLoc())
20043fe7127dSAlex Zinenko           << "value # " << position
20053fe7127dSAlex Zinenko           << " associated with the parameter defined here";
20063fe7127dSAlex Zinenko       return diag;
20073fe7127dSAlex Zinenko     };
20083fe7127dSAlex Zinenko 
20093fe7127dSAlex Zinenko     switch (getPredicate()) {
20103fe7127dSAlex Zinenko     case MatchCmpIPredicate::eq:
20113fe7127dSAlex Zinenko       if (value.eq(refValue))
20123fe7127dSAlex Zinenko         break;
20133fe7127dSAlex Zinenko       return reportError("equal to");
20143fe7127dSAlex Zinenko     case MatchCmpIPredicate::ne:
20153fe7127dSAlex Zinenko       if (value.ne(refValue))
20163fe7127dSAlex Zinenko         break;
20173fe7127dSAlex Zinenko       return reportError("not equal to");
20183fe7127dSAlex Zinenko     case MatchCmpIPredicate::lt:
20193fe7127dSAlex Zinenko       if (value.slt(refValue))
20203fe7127dSAlex Zinenko         break;
20213fe7127dSAlex Zinenko       return reportError("less than");
20223fe7127dSAlex Zinenko     case MatchCmpIPredicate::le:
20233fe7127dSAlex Zinenko       if (value.sle(refValue))
20243fe7127dSAlex Zinenko         break;
20253fe7127dSAlex Zinenko       return reportError("less than or equal to");
20263fe7127dSAlex Zinenko     case MatchCmpIPredicate::gt:
20273fe7127dSAlex Zinenko       if (value.sgt(refValue))
20283fe7127dSAlex Zinenko         break;
20293fe7127dSAlex Zinenko       return reportError("greater than");
20303fe7127dSAlex Zinenko     case MatchCmpIPredicate::ge:
20313fe7127dSAlex Zinenko       if (value.sge(refValue))
20323fe7127dSAlex Zinenko         break;
20333fe7127dSAlex Zinenko       return reportError("greater than or equal to");
20343fe7127dSAlex Zinenko     }
20353fe7127dSAlex Zinenko   }
20363fe7127dSAlex Zinenko   return DiagnosedSilenceableFailure::success();
20373fe7127dSAlex Zinenko }
20383fe7127dSAlex Zinenko 
20393fe7127dSAlex Zinenko void transform::MatchParamCmpIOp::getEffects(
20403fe7127dSAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
20412c1ae801Sdonald chen   onlyReadsHandle(getParamMutable(), effects);
20422c1ae801Sdonald chen   onlyReadsHandle(getReferenceMutable(), effects);
20433fe7127dSAlex Zinenko }
20443fe7127dSAlex Zinenko 
20453fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
20463fe7127dSAlex Zinenko // ParamConstantOp
20473fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
20483fe7127dSAlex Zinenko 
20493fe7127dSAlex Zinenko DiagnosedSilenceableFailure
2050c63d2b2cSMatthias Springer transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2051c63d2b2cSMatthias Springer                                   transform::TransformResults &results,
20523fe7127dSAlex Zinenko                                   transform::TransformState &state) {
20533fe7127dSAlex Zinenko   results.setParams(cast<OpResult>(getParam()), {getValue()});
20543fe7127dSAlex Zinenko   return DiagnosedSilenceableFailure::success();
20553fe7127dSAlex Zinenko }
20563fe7127dSAlex Zinenko 
20573fe7127dSAlex Zinenko //===----------------------------------------------------------------------===//
20588e03bfc3SAlex Zinenko // MergeHandlesOp
20598e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===//
20608e03bfc3SAlex Zinenko 
20618e03bfc3SAlex Zinenko DiagnosedSilenceableFailure
2062c63d2b2cSMatthias Springer transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2063c63d2b2cSMatthias Springer                                  transform::TransformResults &results,
20648e03bfc3SAlex Zinenko                                  transform::TransformState &state) {
206519380396SQuinn Dawkins   ValueRange handles = getHandles();
206619380396SQuinn Dawkins   if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
20678e03bfc3SAlex Zinenko     SmallVector<Operation *> operations;
206819380396SQuinn Dawkins     for (Value operand : handles)
20698e03bfc3SAlex Zinenko       llvm::append_range(operations, state.getPayloadOps(operand));
20708e03bfc3SAlex Zinenko     if (!getDeduplicate()) {
2071c1fa60b4STres Popp       results.set(llvm::cast<OpResult>(getResult()), operations);
20728e03bfc3SAlex Zinenko       return DiagnosedSilenceableFailure::success();
20738e03bfc3SAlex Zinenko     }
20748e03bfc3SAlex Zinenko 
20758e03bfc3SAlex Zinenko     SetVector<Operation *> uniqued(operations.begin(), operations.end());
2076c1fa60b4STres Popp     results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
20778e03bfc3SAlex Zinenko     return DiagnosedSilenceableFailure::success();
20788e03bfc3SAlex Zinenko   }
20798e03bfc3SAlex Zinenko 
208019380396SQuinn Dawkins   if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
208119380396SQuinn Dawkins     SmallVector<Attribute> attrs;
208219380396SQuinn Dawkins     for (Value attribute : handles)
208319380396SQuinn Dawkins       llvm::append_range(attrs, state.getParams(attribute));
208419380396SQuinn Dawkins     if (!getDeduplicate()) {
208519380396SQuinn Dawkins       results.setParams(cast<OpResult>(getResult()), attrs);
208619380396SQuinn Dawkins       return DiagnosedSilenceableFailure::success();
208719380396SQuinn Dawkins     }
208819380396SQuinn Dawkins 
208919380396SQuinn Dawkins     SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
209019380396SQuinn Dawkins     results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
209119380396SQuinn Dawkins     return DiagnosedSilenceableFailure::success();
209219380396SQuinn Dawkins   }
209319380396SQuinn Dawkins 
209419380396SQuinn Dawkins   assert(
209519380396SQuinn Dawkins       llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
209619380396SQuinn Dawkins       "expected value handle type");
209719380396SQuinn Dawkins   SmallVector<Value> payloadValues;
209819380396SQuinn Dawkins   for (Value value : handles)
209919380396SQuinn Dawkins     llvm::append_range(payloadValues, state.getPayloadValues(value));
210019380396SQuinn Dawkins   if (!getDeduplicate()) {
210119380396SQuinn Dawkins     results.setValues(cast<OpResult>(getResult()), payloadValues);
210219380396SQuinn Dawkins     return DiagnosedSilenceableFailure::success();
210319380396SQuinn Dawkins   }
210419380396SQuinn Dawkins 
210519380396SQuinn Dawkins   SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
210619380396SQuinn Dawkins   results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
210719380396SQuinn Dawkins   return DiagnosedSilenceableFailure::success();
210819380396SQuinn Dawkins }
210919380396SQuinn Dawkins 
21104299be1aSAlex Zinenko bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
21114299be1aSAlex Zinenko   // Handles may be the same if deduplicating is enabled.
21124299be1aSAlex Zinenko   return getDeduplicate();
21134299be1aSAlex Zinenko }
21144299be1aSAlex Zinenko 
21158e03bfc3SAlex Zinenko void transform::MergeHandlesOp::getEffects(
21168e03bfc3SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
21172c1ae801Sdonald chen   onlyReadsHandle(getHandlesMutable(), effects);
21182c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
21198e03bfc3SAlex Zinenko 
21208e03bfc3SAlex Zinenko   // There are no effects on the Payload IR as this is only a handle
21218e03bfc3SAlex Zinenko   // manipulation.
21228e03bfc3SAlex Zinenko }
21238e03bfc3SAlex Zinenko 
21247df76121SMarkus Böck OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
21258e03bfc3SAlex Zinenko   if (getDeduplicate() || getHandles().size() != 1)
21268e03bfc3SAlex Zinenko     return {};
21278e03bfc3SAlex Zinenko 
21288e03bfc3SAlex Zinenko   // If deduplication is not required and there is only one operand, it can be
21298e03bfc3SAlex Zinenko   // used directly instead of merging.
21308e03bfc3SAlex Zinenko   return getHandles().front();
21318e03bfc3SAlex Zinenko }
21328e03bfc3SAlex Zinenko 
21338e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===//
2134fb409a28SAlex Zinenko // NamedSequenceOp
2135fb409a28SAlex Zinenko //===----------------------------------------------------------------------===//
2136fb409a28SAlex Zinenko 
2137fb409a28SAlex Zinenko DiagnosedSilenceableFailure
2138c63d2b2cSMatthias Springer transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2139c63d2b2cSMatthias Springer                                   transform::TransformResults &results,
2140fb409a28SAlex Zinenko                                   transform::TransformState &state) {
21411bf08709SNicolas Vasilache   if (isExternal())
21421bf08709SNicolas Vasilache     return emitDefiniteFailure() << "unresolved external named sequence";
21431bf08709SNicolas Vasilache 
21441bf08709SNicolas Vasilache   // Map the entry block argument to the list of operations.
21451bf08709SNicolas Vasilache   // Note: this is the same implementation as PossibleTopLevelTransformOp but
21461bf08709SNicolas Vasilache   // without attaching the interface / trait since that is tailored to a
21471bf08709SNicolas Vasilache   // dangling top-level op that does not get "called".
21481bf08709SNicolas Vasilache   auto scope = state.make_region_scope(getBody());
21491bf08709SNicolas Vasilache   if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
21501bf08709SNicolas Vasilache           state, this->getOperation(), getBody())))
21511bf08709SNicolas Vasilache     return DiagnosedSilenceableFailure::definiteFailure();
21521bf08709SNicolas Vasilache 
21531bf08709SNicolas Vasilache   return applySequenceBlock(getBody().front(),
21541bf08709SNicolas Vasilache                             FailurePropagationMode::Propagate, state, results);
2155fb409a28SAlex Zinenko }
2156fb409a28SAlex Zinenko 
2157fb409a28SAlex Zinenko void transform::NamedSequenceOp::getEffects(
2158fb409a28SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2159fb409a28SAlex Zinenko 
2160fb409a28SAlex Zinenko ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2161fb409a28SAlex Zinenko                                               OperationState &result) {
2162fb409a28SAlex Zinenko   return function_interface_impl::parseFunctionOp(
2163fb409a28SAlex Zinenko       parser, result, /*allowVariadic=*/false,
2164fb409a28SAlex Zinenko       getFunctionTypeAttrName(result.name),
2165fb409a28SAlex Zinenko       [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2166fb409a28SAlex Zinenko          function_interface_impl::VariadicFlag,
2167fb409a28SAlex Zinenko          std::string &) { return builder.getFunctionType(inputs, results); },
2168fb409a28SAlex Zinenko       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2169fb409a28SAlex Zinenko }
2170fb409a28SAlex Zinenko 
2171fb409a28SAlex Zinenko void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2172fb409a28SAlex Zinenko   function_interface_impl::printFunctionOp(
2173fb409a28SAlex Zinenko       printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2174fb409a28SAlex Zinenko       getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2175fb409a28SAlex Zinenko       getResAttrsAttrName());
2176fb409a28SAlex Zinenko }
2177fb409a28SAlex Zinenko 
217863c9d2b1SAlex Zinenko /// Verifies that a symbol function-like transform dialect operation has the
217963c9d2b1SAlex Zinenko /// signature and the terminator that have conforming types, i.e., types
218063c9d2b1SAlex Zinenko /// implementing the same transform dialect type interface. If `allowExternal`
218163c9d2b1SAlex Zinenko /// is set, allow external symbols (declarations) and don't check the terminator
218263c9d2b1SAlex Zinenko /// as it may not exist.
218363c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure
218463c9d2b1SAlex Zinenko verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
218563c9d2b1SAlex Zinenko   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
218663c9d2b1SAlex Zinenko     DiagnosedSilenceableFailure diag =
218763c9d2b1SAlex Zinenko         emitSilenceableFailure(op)
218863c9d2b1SAlex Zinenko         << "cannot be defined inside another transform op";
218963c9d2b1SAlex Zinenko     diag.attachNote(parent.getLoc()) << "ancestor transform op";
219063c9d2b1SAlex Zinenko     return diag;
219163c9d2b1SAlex Zinenko   }
219263c9d2b1SAlex Zinenko 
219363c9d2b1SAlex Zinenko   if (op.isExternal() || op.getFunctionBody().empty()) {
219463c9d2b1SAlex Zinenko     if (allowExternal)
219563c9d2b1SAlex Zinenko       return DiagnosedSilenceableFailure::success();
219663c9d2b1SAlex Zinenko 
219763c9d2b1SAlex Zinenko     return emitSilenceableFailure(op) << "cannot be external";
219863c9d2b1SAlex Zinenko   }
219963c9d2b1SAlex Zinenko 
220063c9d2b1SAlex Zinenko   if (op.getFunctionBody().front().empty())
220163c9d2b1SAlex Zinenko     return emitSilenceableFailure(op) << "expected a non-empty body block";
220263c9d2b1SAlex Zinenko 
220363c9d2b1SAlex Zinenko   Operation *terminator = &op.getFunctionBody().front().back();
220463c9d2b1SAlex Zinenko   if (!isa<transform::YieldOp>(terminator)) {
220563c9d2b1SAlex Zinenko     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
220663c9d2b1SAlex Zinenko                                        << "expected '"
220763c9d2b1SAlex Zinenko                                        << transform::YieldOp::getOperationName()
220863c9d2b1SAlex Zinenko                                        << "' as terminator";
220963c9d2b1SAlex Zinenko     diag.attachNote(terminator->getLoc()) << "terminator";
221063c9d2b1SAlex Zinenko     return diag;
221163c9d2b1SAlex Zinenko   }
221263c9d2b1SAlex Zinenko 
221363c9d2b1SAlex Zinenko   if (terminator->getNumOperands() != op.getResultTypes().size()) {
221463c9d2b1SAlex Zinenko     return emitSilenceableFailure(terminator)
221563c9d2b1SAlex Zinenko            << "expected terminator to have as many operands as the parent op "
221663c9d2b1SAlex Zinenko               "has results";
221763c9d2b1SAlex Zinenko   }
221863c9d2b1SAlex Zinenko   for (auto [i, operandType, resultType] : llvm::zip_equal(
221963c9d2b1SAlex Zinenko            llvm::seq<unsigned>(0, terminator->getNumOperands()),
222063c9d2b1SAlex Zinenko            terminator->getOperands().getType(), op.getResultTypes())) {
222163c9d2b1SAlex Zinenko     if (operandType == resultType)
222263c9d2b1SAlex Zinenko       continue;
222363c9d2b1SAlex Zinenko     return emitSilenceableFailure(terminator)
222463c9d2b1SAlex Zinenko            << "the type of the terminator operand #" << i
222563c9d2b1SAlex Zinenko            << " must match the type of the corresponding parent op result ("
222663c9d2b1SAlex Zinenko            << operandType << " vs " << resultType << ")";
222763c9d2b1SAlex Zinenko   }
222863c9d2b1SAlex Zinenko 
222963c9d2b1SAlex Zinenko   return DiagnosedSilenceableFailure::success();
223063c9d2b1SAlex Zinenko }
223163c9d2b1SAlex Zinenko 
2232fb409a28SAlex Zinenko /// Verification of a NamedSequenceOp. This does not report the error
2233fb409a28SAlex Zinenko /// immediately, so it can be used to check for op's well-formedness before the
2234fb409a28SAlex Zinenko /// verifier runs, e.g., during trait verification.
2235fb409a28SAlex Zinenko static DiagnosedSilenceableFailure
2236135e5bf8SAlex Zinenko verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2237fb409a28SAlex Zinenko   if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2238fb409a28SAlex Zinenko     if (!parent->getAttr(
2239fb409a28SAlex Zinenko             transform::TransformDialect::kWithNamedSequenceAttrName)) {
2240fb409a28SAlex Zinenko       DiagnosedSilenceableFailure diag =
2241fb409a28SAlex Zinenko           emitSilenceableFailure(op)
2242fb409a28SAlex Zinenko           << "expects the parent symbol table to have the '"
2243fb409a28SAlex Zinenko           << transform::TransformDialect::kWithNamedSequenceAttrName
2244fb409a28SAlex Zinenko           << "' attribute";
2245fb409a28SAlex Zinenko       diag.attachNote(parent->getLoc()) << "symbol table operation";
2246fb409a28SAlex Zinenko       return diag;
2247fb409a28SAlex Zinenko     }
2248fb409a28SAlex Zinenko   }
2249fb409a28SAlex Zinenko 
2250fb409a28SAlex Zinenko   if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2251fb409a28SAlex Zinenko     DiagnosedSilenceableFailure diag =
2252fb409a28SAlex Zinenko         emitSilenceableFailure(op)
2253fb409a28SAlex Zinenko         << "cannot be defined inside another transform op";
2254fb409a28SAlex Zinenko     diag.attachNote(parent.getLoc()) << "ancestor transform op";
2255fb409a28SAlex Zinenko     return diag;
2256fb409a28SAlex Zinenko   }
2257fb409a28SAlex Zinenko 
225892c69468SAlex Zinenko   if (op.isExternal() || op.getBody().empty())
2259135e5bf8SAlex Zinenko     return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2260135e5bf8SAlex Zinenko                                                 emitWarnings);
226192c69468SAlex Zinenko 
2262fb409a28SAlex Zinenko   if (op.getBody().front().empty())
2263fb409a28SAlex Zinenko     return emitSilenceableFailure(op) << "expected a non-empty body block";
2264fb409a28SAlex Zinenko 
2265fb409a28SAlex Zinenko   Operation *terminator = &op.getBody().front().back();
2266fb409a28SAlex Zinenko   if (!isa<transform::YieldOp>(terminator)) {
2267fb409a28SAlex Zinenko     DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2268fb409a28SAlex Zinenko                                        << "expected '"
2269fb409a28SAlex Zinenko                                        << transform::YieldOp::getOperationName()
2270fb409a28SAlex Zinenko                                        << "' as terminator";
2271fb409a28SAlex Zinenko     diag.attachNote(terminator->getLoc()) << "terminator";
2272fb409a28SAlex Zinenko     return diag;
2273fb409a28SAlex Zinenko   }
2274fb409a28SAlex Zinenko 
2275fb409a28SAlex Zinenko   if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2276fb409a28SAlex Zinenko     return emitSilenceableFailure(terminator)
2277fb409a28SAlex Zinenko            << "expected terminator to have as many operands as the parent op "
2278fb409a28SAlex Zinenko               "has results";
2279fb409a28SAlex Zinenko   }
2280fb409a28SAlex Zinenko   for (auto [i, operandType, resultType] :
2281fb409a28SAlex Zinenko        llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2282fb409a28SAlex Zinenko                        terminator->getOperands().getType(),
2283fb409a28SAlex Zinenko                        op.getFunctionType().getResults())) {
2284fb409a28SAlex Zinenko     if (operandType == resultType)
2285fb409a28SAlex Zinenko       continue;
2286fb409a28SAlex Zinenko     return emitSilenceableFailure(terminator)
2287fb409a28SAlex Zinenko            << "the type of the terminator operand #" << i
2288fb409a28SAlex Zinenko            << " must match the type of the corresponding parent op result ("
2289fb409a28SAlex Zinenko            << operandType << " vs " << resultType << ")";
2290fb409a28SAlex Zinenko   }
2291fb409a28SAlex Zinenko 
229263c9d2b1SAlex Zinenko   auto funcOp = cast<FunctionOpInterface>(*op);
229363c9d2b1SAlex Zinenko   DiagnosedSilenceableFailure diag =
2294135e5bf8SAlex Zinenko       verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
229563c9d2b1SAlex Zinenko   if (!diag.succeeded())
229663c9d2b1SAlex Zinenko     return diag;
229763c9d2b1SAlex Zinenko 
229863c9d2b1SAlex Zinenko   return verifyYieldingSingleBlockOp(funcOp,
229963c9d2b1SAlex Zinenko                                      /*allowExternal=*/true);
2300fb409a28SAlex Zinenko }
2301fb409a28SAlex Zinenko 
2302fb409a28SAlex Zinenko LogicalResult transform::NamedSequenceOp::verify() {
2303fb409a28SAlex Zinenko   // Actual verification happens in a separate function for reusability.
2304135e5bf8SAlex Zinenko   return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2305fb409a28SAlex Zinenko }
2306fb409a28SAlex Zinenko 
23070c935894SNicolas Vasilache template <typename FnTy>
23080c935894SNicolas Vasilache static void buildSequenceBody(OpBuilder &builder, OperationState &state,
23090c935894SNicolas Vasilache                               Type bbArgType, TypeRange extraBindingTypes,
23100c935894SNicolas Vasilache                               FnTy bodyBuilder) {
23110c935894SNicolas Vasilache   SmallVector<Type> types;
23120c935894SNicolas Vasilache   types.reserve(1 + extraBindingTypes.size());
23130c935894SNicolas Vasilache   types.push_back(bbArgType);
23140c935894SNicolas Vasilache   llvm::append_range(types, extraBindingTypes);
23150c935894SNicolas Vasilache 
23160c935894SNicolas Vasilache   OpBuilder::InsertionGuard guard(builder);
23170c935894SNicolas Vasilache   Region *region = state.regions.back().get();
23180c935894SNicolas Vasilache   Block *bodyBlock =
23190c935894SNicolas Vasilache       builder.createBlock(region, region->begin(), types,
23200c935894SNicolas Vasilache                           SmallVector<Location>(types.size(), state.location));
23210c935894SNicolas Vasilache 
23220c935894SNicolas Vasilache   // Populate body.
23230c935894SNicolas Vasilache   builder.setInsertionPointToStart(bodyBlock);
23240c935894SNicolas Vasilache   if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
23250c935894SNicolas Vasilache     bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
23260c935894SNicolas Vasilache   } else {
23270c935894SNicolas Vasilache     bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
23280c935894SNicolas Vasilache                 bodyBlock->getArguments().drop_front());
23290c935894SNicolas Vasilache   }
23300c935894SNicolas Vasilache }
23310c935894SNicolas Vasilache 
23320c935894SNicolas Vasilache void transform::NamedSequenceOp::build(OpBuilder &builder,
23330c935894SNicolas Vasilache                                        OperationState &state, StringRef symName,
23340c935894SNicolas Vasilache                                        Type rootType, TypeRange resultTypes,
23350c935894SNicolas Vasilache                                        SequenceBodyBuilderFn bodyBuilder,
23360c935894SNicolas Vasilache                                        ArrayRef<NamedAttribute> attrs,
23370c935894SNicolas Vasilache                                        ArrayRef<DictionaryAttr> argAttrs) {
23380c935894SNicolas Vasilache   state.addAttribute(SymbolTable::getSymbolAttrName(),
23390c935894SNicolas Vasilache                      builder.getStringAttr(symName));
23400c935894SNicolas Vasilache   state.addAttribute(getFunctionTypeAttrName(state.name),
23418483d18bSNicolas Vasilache                      TypeAttr::get(FunctionType::get(builder.getContext(),
23428483d18bSNicolas Vasilache                                                      rootType, resultTypes)));
23430c935894SNicolas Vasilache   state.attributes.append(attrs.begin(), attrs.end());
23440c935894SNicolas Vasilache   state.addRegion();
23450c935894SNicolas Vasilache 
23460c935894SNicolas Vasilache   buildSequenceBody(builder, state, rootType,
23470c935894SNicolas Vasilache                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
23480c935894SNicolas Vasilache }
23490c935894SNicolas Vasilache 
2350fb409a28SAlex Zinenko //===----------------------------------------------------------------------===//
2351f90b6090SOleksandr "Alex" Zinenko // NumAssociationsOp
2352f90b6090SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
2353f90b6090SOleksandr "Alex" Zinenko 
2354f90b6090SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure
2355f90b6090SOleksandr "Alex" Zinenko transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2356f90b6090SOleksandr "Alex" Zinenko                                     transform::TransformResults &results,
2357f90b6090SOleksandr "Alex" Zinenko                                     transform::TransformState &state) {
2358f90b6090SOleksandr "Alex" Zinenko   size_t numAssociations =
2359f90b6090SOleksandr "Alex" Zinenko       llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2360f90b6090SOleksandr "Alex" Zinenko           .Case([&](TransformHandleTypeInterface opHandle) {
2361f90b6090SOleksandr "Alex" Zinenko             return llvm::range_size(state.getPayloadOps(getHandle()));
2362f90b6090SOleksandr "Alex" Zinenko           })
2363f90b6090SOleksandr "Alex" Zinenko           .Case([&](TransformValueHandleTypeInterface valueHandle) {
2364f90b6090SOleksandr "Alex" Zinenko             return llvm::range_size(state.getPayloadValues(getHandle()));
2365f90b6090SOleksandr "Alex" Zinenko           })
2366f90b6090SOleksandr "Alex" Zinenko           .Case([&](TransformParamTypeInterface param) {
2367f90b6090SOleksandr "Alex" Zinenko             return llvm::range_size(state.getParams(getHandle()));
2368f90b6090SOleksandr "Alex" Zinenko           })
2369f90b6090SOleksandr "Alex" Zinenko           .Default([](Type) {
2370f90b6090SOleksandr "Alex" Zinenko             llvm_unreachable("unknown kind of transform dialect type");
2371f90b6090SOleksandr "Alex" Zinenko             return 0;
2372f90b6090SOleksandr "Alex" Zinenko           });
2373a5757c5bSChristian Sigg   results.setParams(cast<OpResult>(getNum()),
2374f90b6090SOleksandr "Alex" Zinenko                     rewriter.getI64IntegerAttr(numAssociations));
2375f90b6090SOleksandr "Alex" Zinenko   return DiagnosedSilenceableFailure::success();
2376f90b6090SOleksandr "Alex" Zinenko }
2377f90b6090SOleksandr "Alex" Zinenko 
2378f90b6090SOleksandr "Alex" Zinenko LogicalResult transform::NumAssociationsOp::verify() {
2379f90b6090SOleksandr "Alex" Zinenko   // Verify that the result type accepts an i64 attribute as payload.
2380a5757c5bSChristian Sigg   auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2381f90b6090SOleksandr "Alex" Zinenko   return resultType
2382f90b6090SOleksandr "Alex" Zinenko       .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2383f90b6090SOleksandr "Alex" Zinenko       .checkAndReport();
2384f90b6090SOleksandr "Alex" Zinenko }
2385f90b6090SOleksandr "Alex" Zinenko 
2386f90b6090SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
2387894fdbc7SMatthias Springer // SelectOp
2388894fdbc7SMatthias Springer //===----------------------------------------------------------------------===//
2389894fdbc7SMatthias Springer 
2390894fdbc7SMatthias Springer DiagnosedSilenceableFailure
2391894fdbc7SMatthias Springer transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2392894fdbc7SMatthias Springer                            transform::TransformResults &results,
2393894fdbc7SMatthias Springer                            transform::TransformState &state) {
2394894fdbc7SMatthias Springer   SmallVector<Operation *> result;
2395894fdbc7SMatthias Springer   auto payloadOps = state.getPayloadOps(getTarget());
2396894fdbc7SMatthias Springer   for (Operation *op : payloadOps) {
2397894fdbc7SMatthias Springer     if (op->getName().getStringRef() == getOpName())
2398894fdbc7SMatthias Springer       result.push_back(op);
2399894fdbc7SMatthias Springer   }
2400894fdbc7SMatthias Springer   results.set(cast<OpResult>(getResult()), result);
2401894fdbc7SMatthias Springer   return DiagnosedSilenceableFailure::success();
2402894fdbc7SMatthias Springer }
2403894fdbc7SMatthias Springer 
2404894fdbc7SMatthias Springer //===----------------------------------------------------------------------===//
2405288529e7SMatthias Springer // SplitHandleOp
2406af664e44SNicolas Vasilache //===----------------------------------------------------------------------===//
2407af664e44SNicolas Vasilache 
2408288529e7SMatthias Springer void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2409288529e7SMatthias Springer                                      Value target, int64_t numResultHandles) {
2410c8fab80dSNicolas Vasilache   result.addOperands(target);
241194d608d4SAlex Zinenko   result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2412c8fab80dSNicolas Vasilache }
2413c8fab80dSNicolas Vasilache 
2414af664e44SNicolas Vasilache DiagnosedSilenceableFailure
2415c63d2b2cSMatthias Springer transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2416c63d2b2cSMatthias Springer                                 transform::TransformResults &results,
2417af664e44SNicolas Vasilache                                 transform::TransformState &state) {
24181c352e66SOleksandr "Alex" Zinenko   int64_t numPayloads =
24191c352e66SOleksandr "Alex" Zinenko       llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
24201c352e66SOleksandr "Alex" Zinenko           .Case<TransformHandleTypeInterface>([&](auto x) {
24211c352e66SOleksandr "Alex" Zinenko             return llvm::range_size(state.getPayloadOps(getHandle()));
24221c352e66SOleksandr "Alex" Zinenko           })
24231c352e66SOleksandr "Alex" Zinenko           .Case<TransformValueHandleTypeInterface>([&](auto x) {
24241c352e66SOleksandr "Alex" Zinenko             return llvm::range_size(state.getPayloadValues(getHandle()));
24251c352e66SOleksandr "Alex" Zinenko           })
24261c352e66SOleksandr "Alex" Zinenko           .Case<TransformParamTypeInterface>([&](auto x) {
24271c352e66SOleksandr "Alex" Zinenko             return llvm::range_size(state.getParams(getHandle()));
24281c352e66SOleksandr "Alex" Zinenko           })
24291c352e66SOleksandr "Alex" Zinenko           .Default([](auto x) {
24301c352e66SOleksandr "Alex" Zinenko             llvm_unreachable("unknown transform dialect type interface");
24311c352e66SOleksandr "Alex" Zinenko             return -1;
24321c352e66SOleksandr "Alex" Zinenko           });
24331c352e66SOleksandr "Alex" Zinenko 
2434709098fbSMatthias Springer   auto produceNumOpsError = [&]() {
2435af664e44SNicolas Vasilache     return emitSilenceableError()
2436709098fbSMatthias Springer            << getHandle() << " expected to contain " << this->getNumResults()
24371c352e66SOleksandr "Alex" Zinenko            << " payloads but it contains " << numPayloads << " payloads";
2438709098fbSMatthias Springer   };
2439288529e7SMatthias Springer 
2440709098fbSMatthias Springer   // Fail if there are more payload ops than results and no overflow result was
2441709098fbSMatthias Springer   // specified.
24421c352e66SOleksandr "Alex" Zinenko   if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2443709098fbSMatthias Springer     return produceNumOpsError();
2444709098fbSMatthias Springer 
2445709098fbSMatthias Springer   // Fail if there are more results than payload ops. Unless:
2446709098fbSMatthias Springer   // - "fail_on_payload_too_small" is set to "false", or
2447709098fbSMatthias Springer   // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
24481c352e66SOleksandr "Alex" Zinenko   if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
24491c352e66SOleksandr "Alex" Zinenko       (numPayloads != 0 || !getPassThroughEmptyHandle()))
2450709098fbSMatthias Springer     return produceNumOpsError();
2451709098fbSMatthias Springer 
24521c352e66SOleksandr "Alex" Zinenko   // Distribute payloads.
24531c352e66SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2454709098fbSMatthias Springer   if (getOverflowResult())
24551c352e66SOleksandr "Alex" Zinenko     resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
24561c352e66SOleksandr "Alex" Zinenko 
24571c352e66SOleksandr "Alex" Zinenko   auto container = [&]() {
24581c352e66SOleksandr "Alex" Zinenko     if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
24591c352e66SOleksandr "Alex" Zinenko       return llvm::map_to_vector(
24601c352e66SOleksandr "Alex" Zinenko           state.getPayloadOps(getHandle()),
24611c352e66SOleksandr "Alex" Zinenko           [](Operation *op) -> MappedValue { return op; });
24621c352e66SOleksandr "Alex" Zinenko     }
24631c352e66SOleksandr "Alex" Zinenko     if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
24641c352e66SOleksandr "Alex" Zinenko       return llvm::map_to_vector(state.getPayloadValues(getHandle()),
24651c352e66SOleksandr "Alex" Zinenko                                  [](Value v) -> MappedValue { return v; });
24661c352e66SOleksandr "Alex" Zinenko     }
24671c352e66SOleksandr "Alex" Zinenko     assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
24681c352e66SOleksandr "Alex" Zinenko            "unsupported kind of transform dialect type");
24691c352e66SOleksandr "Alex" Zinenko     return llvm::map_to_vector(state.getParams(getHandle()),
24701c352e66SOleksandr "Alex" Zinenko                                [](Attribute a) -> MappedValue { return a; });
24711c352e66SOleksandr "Alex" Zinenko   }();
24721c352e66SOleksandr "Alex" Zinenko 
24731c352e66SOleksandr "Alex" Zinenko   for (auto &&en : llvm::enumerate(container)) {
2474709098fbSMatthias Springer     int64_t resultNum = en.index();
2475709098fbSMatthias Springer     if (resultNum >= getNumResults())
2476709098fbSMatthias Springer       resultNum = *getOverflowResult();
2477709098fbSMatthias Springer     resultHandles[resultNum].push_back(en.value());
2478709098fbSMatthias Springer   }
2479709098fbSMatthias Springer 
2480709098fbSMatthias Springer   // Set transform op results.
2481709098fbSMatthias Springer   for (auto &&it : llvm::enumerate(resultHandles))
24821c352e66SOleksandr "Alex" Zinenko     results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
24831c352e66SOleksandr "Alex" Zinenko                             it.value());
2484288529e7SMatthias Springer 
2485af664e44SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
2486af664e44SNicolas Vasilache }
2487af664e44SNicolas Vasilache 
2488288529e7SMatthias Springer void transform::SplitHandleOp::getEffects(
2489af664e44SNicolas Vasilache     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
24902c1ae801Sdonald chen   onlyReadsHandle(getHandleMutable(), effects);
24912c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
2492af664e44SNicolas Vasilache   // There are no effects on the Payload IR as this is only a handle
2493af664e44SNicolas Vasilache   // manipulation.
2494af664e44SNicolas Vasilache }
2495af664e44SNicolas Vasilache 
2496709098fbSMatthias Springer LogicalResult transform::SplitHandleOp::verify() {
2497709098fbSMatthias Springer   if (getOverflowResult().has_value() &&
249880815dfbSlong.chen       !(*getOverflowResult() < getNumResults()))
2499709098fbSMatthias Springer     return emitOpError("overflow_result is not a valid result index");
25001c352e66SOleksandr "Alex" Zinenko 
25011c352e66SOleksandr "Alex" Zinenko   for (Type resultType : getResultTypes()) {
25021c352e66SOleksandr "Alex" Zinenko     if (implementSameTransformInterface(getHandle().getType(), resultType))
25031c352e66SOleksandr "Alex" Zinenko       continue;
25041c352e66SOleksandr "Alex" Zinenko 
25051c352e66SOleksandr "Alex" Zinenko     return emitOpError("expects result types to implement the same transform "
25061c352e66SOleksandr "Alex" Zinenko                        "interface as the operand type");
25071c352e66SOleksandr "Alex" Zinenko   }
25081c352e66SOleksandr "Alex" Zinenko 
2509709098fbSMatthias Springer   return success();
2510709098fbSMatthias Springer }
2511709098fbSMatthias Springer 
2512af664e44SNicolas Vasilache //===----------------------------------------------------------------------===//
251300d1a1a2SAlex Zinenko // ReplicateOp
251400d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
251500d1a1a2SAlex Zinenko 
251600d1a1a2SAlex Zinenko DiagnosedSilenceableFailure
2517c63d2b2cSMatthias Springer transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2518c63d2b2cSMatthias Springer                               transform::TransformResults &results,
251900d1a1a2SAlex Zinenko                               transform::TransformState &state) {
25200e37ef08SMatthias Springer   unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
252100d1a1a2SAlex Zinenko   for (const auto &en : llvm::enumerate(getHandles())) {
252200d1a1a2SAlex Zinenko     Value handle = en.value();
25230e37ef08SMatthias Springer     if (isa<TransformHandleTypeInterface>(handle.getType())) {
25240e37ef08SMatthias Springer       SmallVector<Operation *> current =
25250e37ef08SMatthias Springer           llvm::to_vector(state.getPayloadOps(handle));
252600d1a1a2SAlex Zinenko       SmallVector<Operation *> payload;
252700d1a1a2SAlex Zinenko       payload.reserve(numRepetitions * current.size());
252800d1a1a2SAlex Zinenko       for (unsigned i = 0; i < numRepetitions; ++i)
252900d1a1a2SAlex Zinenko         llvm::append_range(payload, current);
2530c1fa60b4STres Popp       results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
253188c5027bSAlex Zinenko     } else {
2532c1fa60b4STres Popp       assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
253388c5027bSAlex Zinenko              "expected param type");
253488c5027bSAlex Zinenko       ArrayRef<Attribute> current = state.getParams(handle);
253588c5027bSAlex Zinenko       SmallVector<Attribute> params;
253688c5027bSAlex Zinenko       params.reserve(numRepetitions * current.size());
253788c5027bSAlex Zinenko       for (unsigned i = 0; i < numRepetitions; ++i)
253888c5027bSAlex Zinenko         llvm::append_range(params, current);
2539c1fa60b4STres Popp       results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2540c1fa60b4STres Popp                         params);
254188c5027bSAlex Zinenko     }
254200d1a1a2SAlex Zinenko   }
254300d1a1a2SAlex Zinenko   return DiagnosedSilenceableFailure::success();
254400d1a1a2SAlex Zinenko }
254500d1a1a2SAlex Zinenko 
254600d1a1a2SAlex Zinenko void transform::ReplicateOp::getEffects(
254700d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
25482c1ae801Sdonald chen   onlyReadsHandle(getPatternMutable(), effects);
25492c1ae801Sdonald chen   onlyReadsHandle(getHandlesMutable(), effects);
25502c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
255100d1a1a2SAlex Zinenko }
255200d1a1a2SAlex Zinenko 
255300d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
255430f22429SAlex Zinenko // SequenceOp
255530f22429SAlex Zinenko //===----------------------------------------------------------------------===//
255630f22429SAlex Zinenko 
25571d45282aSAlex Zinenko DiagnosedSilenceableFailure
2558c63d2b2cSMatthias Springer transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2559c63d2b2cSMatthias Springer                              transform::TransformResults &results,
25600eb403adSAlex Zinenko                              transform::TransformState &state) {
25610eb403adSAlex Zinenko   // Map the entry block argument to the list of operations.
25620eb403adSAlex Zinenko   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2563bba85ebdSAlex Zinenko   if (failed(mapBlockArguments(state)))
2564bba85ebdSAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
25650eb403adSAlex Zinenko 
2566fb409a28SAlex Zinenko   return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2567fb409a28SAlex Zinenko                             results);
25680eb403adSAlex Zinenko }
25690eb403adSAlex Zinenko 
2570b9e40cdeSAlex Zinenko static ParseResult parseSequenceOpOperands(
257105423905SKazu Hirata     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2572b9e40cdeSAlex Zinenko     Type &rootType,
2573b9e40cdeSAlex Zinenko     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2574b9e40cdeSAlex Zinenko     SmallVectorImpl<Type> &extraBindingTypes) {
2575b9e40cdeSAlex Zinenko   OpAsmParser::UnresolvedOperand rootOperand;
2576b9e40cdeSAlex Zinenko   OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2577b9e40cdeSAlex Zinenko   if (!hasRoot.has_value()) {
2578b9e40cdeSAlex Zinenko     root = std::nullopt;
2579b9e40cdeSAlex Zinenko     return success();
2580b9e40cdeSAlex Zinenko   }
2581b9e40cdeSAlex Zinenko   if (failed(hasRoot.value()))
2582b9e40cdeSAlex Zinenko     return failure();
2583b9e40cdeSAlex Zinenko   root = rootOperand;
2584b9e40cdeSAlex Zinenko 
2585b9e40cdeSAlex Zinenko   if (succeeded(parser.parseOptionalComma())) {
2586b9e40cdeSAlex Zinenko     if (failed(parser.parseOperandList(extraBindings)))
2587b9e40cdeSAlex Zinenko       return failure();
2588b9e40cdeSAlex Zinenko   }
2589b9e40cdeSAlex Zinenko   if (failed(parser.parseColon()))
2590b9e40cdeSAlex Zinenko     return failure();
2591b9e40cdeSAlex Zinenko 
2592b9e40cdeSAlex Zinenko   // The paren is truly optional.
2593b9e40cdeSAlex Zinenko   (void)parser.parseOptionalLParen();
2594b9e40cdeSAlex Zinenko 
2595b9e40cdeSAlex Zinenko   if (failed(parser.parseType(rootType))) {
2596b9e40cdeSAlex Zinenko     return failure();
2597b9e40cdeSAlex Zinenko   }
2598b9e40cdeSAlex Zinenko 
2599b9e40cdeSAlex Zinenko   if (!extraBindings.empty()) {
2600b9e40cdeSAlex Zinenko     if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2601b9e40cdeSAlex Zinenko       return failure();
2602b9e40cdeSAlex Zinenko   }
2603b9e40cdeSAlex Zinenko 
2604b9e40cdeSAlex Zinenko   if (extraBindingTypes.size() != extraBindings.size()) {
2605b9e40cdeSAlex Zinenko     return parser.emitError(parser.getNameLoc(),
2606b9e40cdeSAlex Zinenko                             "expected types to be provided for all operands");
2607b9e40cdeSAlex Zinenko   }
2608b9e40cdeSAlex Zinenko 
2609b9e40cdeSAlex Zinenko   // The paren is truly optional.
2610b9e40cdeSAlex Zinenko   (void)parser.parseOptionalRParen();
2611b9e40cdeSAlex Zinenko   return success();
2612b9e40cdeSAlex Zinenko }
2613b9e40cdeSAlex Zinenko 
2614b9e40cdeSAlex Zinenko static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2615b9e40cdeSAlex Zinenko                                     Value root, Type rootType,
2616b9e40cdeSAlex Zinenko                                     ValueRange extraBindings,
2617b9e40cdeSAlex Zinenko                                     TypeRange extraBindingTypes) {
2618b9e40cdeSAlex Zinenko   if (!root)
2619b9e40cdeSAlex Zinenko     return;
2620b9e40cdeSAlex Zinenko 
2621b9e40cdeSAlex Zinenko   printer << root;
2622b9e40cdeSAlex Zinenko   bool hasExtras = !extraBindings.empty();
2623b9e40cdeSAlex Zinenko   if (hasExtras) {
2624b9e40cdeSAlex Zinenko     printer << ", ";
2625b9e40cdeSAlex Zinenko     printer.printOperands(extraBindings);
2626b9e40cdeSAlex Zinenko   }
2627b9e40cdeSAlex Zinenko 
2628b9e40cdeSAlex Zinenko   printer << " : ";
2629b9e40cdeSAlex Zinenko   if (hasExtras)
2630b9e40cdeSAlex Zinenko     printer << "(";
2631b9e40cdeSAlex Zinenko 
2632b9e40cdeSAlex Zinenko   printer << rootType;
2633b9e40cdeSAlex Zinenko   if (hasExtras) {
2634b9e40cdeSAlex Zinenko     printer << ", ";
2635b9e40cdeSAlex Zinenko     llvm::interleaveComma(extraBindingTypes, printer.getStream());
2636b9e40cdeSAlex Zinenko     printer << ")";
2637b9e40cdeSAlex Zinenko   }
2638b9e40cdeSAlex Zinenko }
2639b9e40cdeSAlex Zinenko 
264040a8bd63SAlex Zinenko /// Returns `true` if the given op operand may be consuming the handle value in
264140a8bd63SAlex Zinenko /// the Transform IR. That is, if it may have a Free effect on it.
264240a8bd63SAlex Zinenko static bool isValueUsePotentialConsumer(OpOperand &use) {
264340a8bd63SAlex Zinenko   // Conservatively assume the effect being present in absence of the interface.
2644e15b855eSAlex Zinenko   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2645e15b855eSAlex Zinenko   if (!iface)
264640a8bd63SAlex Zinenko     return true;
264740a8bd63SAlex Zinenko 
2648e15b855eSAlex Zinenko   return isHandleConsumed(use.get(), iface);
264940a8bd63SAlex Zinenko }
265040a8bd63SAlex Zinenko 
265140a8bd63SAlex Zinenko LogicalResult
265240a8bd63SAlex Zinenko checkDoubleConsume(Value value,
265340a8bd63SAlex Zinenko                    function_ref<InFlightDiagnostic()> reportError) {
265440a8bd63SAlex Zinenko   OpOperand *potentialConsumer = nullptr;
265540a8bd63SAlex Zinenko   for (OpOperand &use : value.getUses()) {
265640a8bd63SAlex Zinenko     if (!isValueUsePotentialConsumer(use))
265740a8bd63SAlex Zinenko       continue;
265840a8bd63SAlex Zinenko 
265940a8bd63SAlex Zinenko     if (!potentialConsumer) {
266040a8bd63SAlex Zinenko       potentialConsumer = &use;
266140a8bd63SAlex Zinenko       continue;
266240a8bd63SAlex Zinenko     }
266340a8bd63SAlex Zinenko 
266440a8bd63SAlex Zinenko     InFlightDiagnostic diag = reportError()
266540a8bd63SAlex Zinenko                               << " has more than one potential consumer";
266640a8bd63SAlex Zinenko     diag.attachNote(potentialConsumer->getOwner()->getLoc())
266740a8bd63SAlex Zinenko         << "used here as operand #" << potentialConsumer->getOperandNumber();
266840a8bd63SAlex Zinenko     diag.attachNote(use.getOwner()->getLoc())
266940a8bd63SAlex Zinenko         << "used here as operand #" << use.getOperandNumber();
267040a8bd63SAlex Zinenko     return diag;
267140a8bd63SAlex Zinenko   }
267240a8bd63SAlex Zinenko 
267340a8bd63SAlex Zinenko   return success();
267440a8bd63SAlex Zinenko }
267540a8bd63SAlex Zinenko 
26760eb403adSAlex Zinenko LogicalResult transform::SequenceOp::verify() {
2677b9e40cdeSAlex Zinenko   assert(getBodyBlock()->getNumArguments() >= 1 &&
2678b9e40cdeSAlex Zinenko          "the number of arguments must have been verified to be more than 1 by "
2679df969f66SAlex Zinenko          "PossibleTopLevelTransformOpTrait");
2680df969f66SAlex Zinenko 
2681b9e40cdeSAlex Zinenko   if (!getRoot() && !getExtraBindings().empty()) {
2682b9e40cdeSAlex Zinenko     return emitOpError()
2683b9e40cdeSAlex Zinenko            << "does not expect extra operands when used as top-level";
2684df969f66SAlex Zinenko   }
2685df969f66SAlex Zinenko 
2686b9e40cdeSAlex Zinenko   // Check if a block argument has more than one consuming use.
2687b9e40cdeSAlex Zinenko   for (BlockArgument arg : getBodyBlock()->getArguments()) {
2688b9e40cdeSAlex Zinenko     if (failed(checkDoubleConsume(arg, [this, arg]() {
2689b9e40cdeSAlex Zinenko           return (emitOpError() << "block argument #" << arg.getArgNumber());
2690b9e40cdeSAlex Zinenko         }))) {
269140a8bd63SAlex Zinenko       return failure();
269240a8bd63SAlex Zinenko     }
2693b9e40cdeSAlex Zinenko   }
269440a8bd63SAlex Zinenko 
269540a8bd63SAlex Zinenko   // Check properties of the nested operations they cannot check themselves.
26960eb403adSAlex Zinenko   for (Operation &child : *getBodyBlock()) {
26970eb403adSAlex Zinenko     if (!isa<TransformOpInterface>(child) &&
26980eb403adSAlex Zinenko         &child != &getBodyBlock()->back()) {
26990eb403adSAlex Zinenko       InFlightDiagnostic diag =
27000eb403adSAlex Zinenko           emitOpError()
27010eb403adSAlex Zinenko           << "expected children ops to implement TransformOpInterface";
27020eb403adSAlex Zinenko       diag.attachNote(child.getLoc()) << "op without interface";
27030eb403adSAlex Zinenko       return diag;
27040eb403adSAlex Zinenko     }
27050eb403adSAlex Zinenko 
27060eb403adSAlex Zinenko     for (OpResult result : child.getResults()) {
270740a8bd63SAlex Zinenko       auto report = [&]() {
270840a8bd63SAlex Zinenko         return (child.emitError() << "result #" << result.getResultNumber());
270940a8bd63SAlex Zinenko       };
271040a8bd63SAlex Zinenko       if (failed(checkDoubleConsume(result, report)))
271140a8bd63SAlex Zinenko         return failure();
27120eb403adSAlex Zinenko     }
27130eb403adSAlex Zinenko   }
27140eb403adSAlex Zinenko 
271502981c96Svic   if (!getBodyBlock()->mightHaveTerminator())
2716a2a1dbb5SOleksandr "Alex" Zinenko     return emitOpError() << "expects to have a terminator in the body";
2717a2a1dbb5SOleksandr "Alex" Zinenko 
27180eb403adSAlex Zinenko   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
27190eb403adSAlex Zinenko       getOperation()->getResultTypes()) {
27200eb403adSAlex Zinenko     InFlightDiagnostic diag = emitOpError()
27210eb403adSAlex Zinenko                               << "expects the types of the terminator operands "
27220eb403adSAlex Zinenko                                  "to match the types of the result";
27230eb403adSAlex Zinenko     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
27240eb403adSAlex Zinenko     return diag;
27250eb403adSAlex Zinenko   }
27260eb403adSAlex Zinenko   return success();
27270eb403adSAlex Zinenko }
272830f22429SAlex Zinenko 
27290242b962SAlex Zinenko void transform::SequenceOp::getEffects(
27300242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
273194d608d4SAlex Zinenko   getPotentialTopLevelEffects(effects);
273240a8bd63SAlex Zinenko }
273340a8bd63SAlex Zinenko 
27344dd744acSMarkus Böck OperandRange
27354dd744acSMarkus Böck transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
27364dd744acSMarkus Böck   assert(point == getBody() && "unexpected region index");
2737b9e40cdeSAlex Zinenko   if (getOperation()->getNumOperands() > 0)
273873c3dff1SAlex Zinenko     return getOperation()->getOperands();
273973c3dff1SAlex Zinenko   return OperandRange(getOperation()->operand_end(),
274073c3dff1SAlex Zinenko                       getOperation()->operand_end());
274173c3dff1SAlex Zinenko }
274273c3dff1SAlex Zinenko 
274373c3dff1SAlex Zinenko void transform::SequenceOp::getSuccessorRegions(
27444dd744acSMarkus Böck     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
27454dd744acSMarkus Böck   if (point.isParent()) {
274673c3dff1SAlex Zinenko     Region *bodyRegion = &getBody();
2747138df298SMarkus Böck     regions.emplace_back(bodyRegion, getNumOperands() != 0
274873c3dff1SAlex Zinenko                                          ? bodyRegion->getArguments()
274973c3dff1SAlex Zinenko                                          : Block::BlockArgListType());
275073c3dff1SAlex Zinenko     return;
275173c3dff1SAlex Zinenko   }
275273c3dff1SAlex Zinenko 
27534dd744acSMarkus Böck   assert(point == getBody() && "unexpected region index");
275473c3dff1SAlex Zinenko   regions.emplace_back(getOperation()->getResults());
275573c3dff1SAlex Zinenko }
275673c3dff1SAlex Zinenko 
275773c3dff1SAlex Zinenko void transform::SequenceOp::getRegionInvocationBounds(
275873c3dff1SAlex Zinenko     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
275973c3dff1SAlex Zinenko   (void)operands;
276073c3dff1SAlex Zinenko   bounds.emplace_back(1, 1);
276173c3dff1SAlex Zinenko }
276273c3dff1SAlex Zinenko 
276300c95b19SMatthias Springer void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
276400c95b19SMatthias Springer                                   TypeRange resultTypes,
276500c95b19SMatthias Springer                                   FailurePropagationMode failurePropagationMode,
276600c95b19SMatthias Springer                                   Value root,
276700c95b19SMatthias Springer                                   SequenceBodyBuilderFn bodyBuilder) {
2768b9e40cdeSAlex Zinenko   build(builder, state, resultTypes, failurePropagationMode, root,
276901b9d355SAdrian Kuegel         /*extra_bindings=*/ValueRange());
2770df969f66SAlex Zinenko   Type bbArgType = root.getType();
2771b9e40cdeSAlex Zinenko   buildSequenceBody(builder, state, bbArgType,
2772b9e40cdeSAlex Zinenko                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2773b9e40cdeSAlex Zinenko }
277400c95b19SMatthias Springer 
2775b9e40cdeSAlex Zinenko void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2776b9e40cdeSAlex Zinenko                                   TypeRange resultTypes,
2777b9e40cdeSAlex Zinenko                                   FailurePropagationMode failurePropagationMode,
2778b9e40cdeSAlex Zinenko                                   Value root, ValueRange extraBindings,
2779b9e40cdeSAlex Zinenko                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2780b9e40cdeSAlex Zinenko   build(builder, state, resultTypes, failurePropagationMode, root,
2781b9e40cdeSAlex Zinenko         extraBindings);
2782b9e40cdeSAlex Zinenko   buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2783b9e40cdeSAlex Zinenko                     bodyBuilder);
278400c95b19SMatthias Springer }
278500c95b19SMatthias Springer 
278600c95b19SMatthias Springer void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
278700c95b19SMatthias Springer                                   TypeRange resultTypes,
278800c95b19SMatthias Springer                                   FailurePropagationMode failurePropagationMode,
278900c95b19SMatthias Springer                                   Type bbArgType,
279000c95b19SMatthias Springer                                   SequenceBodyBuilderFn bodyBuilder) {
2791b9e40cdeSAlex Zinenko   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
279201b9d355SAdrian Kuegel         /*extra_bindings=*/ValueRange());
2793b9e40cdeSAlex Zinenko   buildSequenceBody(builder, state, bbArgType,
2794b9e40cdeSAlex Zinenko                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2795b9e40cdeSAlex Zinenko }
279600c95b19SMatthias Springer 
2797b9e40cdeSAlex Zinenko void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2798b9e40cdeSAlex Zinenko                                   TypeRange resultTypes,
2799b9e40cdeSAlex Zinenko                                   FailurePropagationMode failurePropagationMode,
2800b9e40cdeSAlex Zinenko                                   Type bbArgType, TypeRange extraBindingTypes,
2801b9e40cdeSAlex Zinenko                                   SequenceBodyBuilderArgsFn bodyBuilder) {
2802b9e40cdeSAlex Zinenko   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
280301b9d355SAdrian Kuegel         /*extra_bindings=*/ValueRange());
2804b9e40cdeSAlex Zinenko   buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
280500c95b19SMatthias Springer }
280600c95b19SMatthias Springer 
280730f22429SAlex Zinenko //===----------------------------------------------------------------------===//
28084b428364SMatthias Springer // PrintOp
28094b428364SMatthias Springer //===----------------------------------------------------------------------===//
28104b428364SMatthias Springer 
2811c8fab80dSNicolas Vasilache void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2812c8fab80dSNicolas Vasilache                                StringRef name) {
2813214ce4daSJinyun (Joey) Ye   if (!name.empty())
2814214ce4daSJinyun (Joey) Ye     result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2815c8fab80dSNicolas Vasilache }
2816c8fab80dSNicolas Vasilache 
2817c8fab80dSNicolas Vasilache void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2818c8fab80dSNicolas Vasilache                                Value target, StringRef name) {
2819c8fab80dSNicolas Vasilache   result.addOperands({target});
2820c8fab80dSNicolas Vasilache   build(builder, result, name);
2821c8fab80dSNicolas Vasilache }
2822c8fab80dSNicolas Vasilache 
28234b428364SMatthias Springer DiagnosedSilenceableFailure
2824c63d2b2cSMatthias Springer transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2825c63d2b2cSMatthias Springer                           transform::TransformResults &results,
28264b428364SMatthias Springer                           transform::TransformState &state) {
282706ca5c81SNicolas Vasilache   llvm::outs() << "[[[ IR printer: ";
28284b428364SMatthias Springer   if (getName().has_value())
282906ca5c81SNicolas Vasilache     llvm::outs() << *getName() << " ";
28304b428364SMatthias Springer 
2831a8cfa7cbSJakub Kuderski   OpPrintingFlags printFlags;
2832a8cfa7cbSJakub Kuderski   if (getAssumeVerified().value_or(false))
2833a8cfa7cbSJakub Kuderski     printFlags.assumeVerified();
2834a8cfa7cbSJakub Kuderski   if (getUseLocalScope().value_or(false))
2835a8cfa7cbSJakub Kuderski     printFlags.useLocalScope();
2836a8cfa7cbSJakub Kuderski   if (getSkipRegions().value_or(false))
2837a8cfa7cbSJakub Kuderski     printFlags.skipRegions();
2838a8cfa7cbSJakub Kuderski 
28394b428364SMatthias Springer   if (!getTarget()) {
2840a8cfa7cbSJakub Kuderski     llvm::outs() << "top-level ]]]\n";
2841a8cfa7cbSJakub Kuderski     state.getTopLevel()->print(llvm::outs(), printFlags);
2842a8cfa7cbSJakub Kuderski     llvm::outs() << "\n";
2843*f6bfbc87SOleksandr "Alex" Zinenko     llvm::outs().flush();
28444b428364SMatthias Springer     return DiagnosedSilenceableFailure::success();
28454b428364SMatthias Springer   }
28464b428364SMatthias Springer 
284706ca5c81SNicolas Vasilache   llvm::outs() << "]]]\n";
2848a8cfa7cbSJakub Kuderski   for (Operation *target : state.getPayloadOps(getTarget())) {
2849a8cfa7cbSJakub Kuderski     target->print(llvm::outs(), printFlags);
2850a8cfa7cbSJakub Kuderski     llvm::outs() << "\n";
2851a8cfa7cbSJakub Kuderski   }
28524b428364SMatthias Springer 
2853*f6bfbc87SOleksandr "Alex" Zinenko   llvm::outs().flush();
28544b428364SMatthias Springer   return DiagnosedSilenceableFailure::success();
28554b428364SMatthias Springer }
28564b428364SMatthias Springer 
28574b428364SMatthias Springer void transform::PrintOp::getEffects(
28584b428364SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2859fa1b807bSOleksandr "Alex" Zinenko   // We don't really care about mutability here, but `getTarget` now
2860fa1b807bSOleksandr "Alex" Zinenko   // unconditionally casts to a specific type before verification could run
2861fa1b807bSOleksandr "Alex" Zinenko   // here.
2862fa1b807bSOleksandr "Alex" Zinenko   if (!getTargetMutable().empty())
28632c1ae801Sdonald chen     onlyReadsHandle(getTargetMutable()[0], effects);
28644b428364SMatthias Springer   onlyReadsPayload(effects);
28654b428364SMatthias Springer 
28664b428364SMatthias Springer   // There is no resource for stderr file descriptor, so just declare print
28674b428364SMatthias Springer   // writes into the default resource.
28684b428364SMatthias Springer   effects.emplace_back(MemoryEffects::Write::get());
28694b428364SMatthias Springer }
28700242b962SAlex Zinenko 
28710242b962SAlex Zinenko //===----------------------------------------------------------------------===//
28727dfcd4b7SMatthias Springer // VerifyOp
28737dfcd4b7SMatthias Springer //===----------------------------------------------------------------------===//
28747dfcd4b7SMatthias Springer 
28757dfcd4b7SMatthias Springer DiagnosedSilenceableFailure
28767dfcd4b7SMatthias Springer transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
28777dfcd4b7SMatthias Springer                                 Operation *target,
28787dfcd4b7SMatthias Springer                                 transform::ApplyToEachResultList &results,
28797dfcd4b7SMatthias Springer                                 transform::TransformState &state) {
28807dfcd4b7SMatthias Springer   if (failed(::mlir::verify(target))) {
28817dfcd4b7SMatthias Springer     DiagnosedDefiniteFailure diag = emitDefiniteFailure()
28827dfcd4b7SMatthias Springer                                     << "failed to verify payload op";
28837dfcd4b7SMatthias Springer     diag.attachNote(target->getLoc()) << "payload op";
28847dfcd4b7SMatthias Springer     return diag;
28857dfcd4b7SMatthias Springer   }
28867dfcd4b7SMatthias Springer   return DiagnosedSilenceableFailure::success();
28877dfcd4b7SMatthias Springer }
28887dfcd4b7SMatthias Springer 
28897dfcd4b7SMatthias Springer void transform::VerifyOp::getEffects(
28907dfcd4b7SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
28912c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
28927dfcd4b7SMatthias Springer }
28937dfcd4b7SMatthias Springer 
28947dfcd4b7SMatthias Springer //===----------------------------------------------------------------------===//
28950242b962SAlex Zinenko // YieldOp
28960242b962SAlex Zinenko //===----------------------------------------------------------------------===//
28970242b962SAlex Zinenko 
28980242b962SAlex Zinenko void transform::YieldOp::getEffects(
28990242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
29002c1ae801Sdonald chen   onlyReadsHandle(getOperandsMutable(), effects);
29010242b962SAlex Zinenko }
2902