1c63d2b2cSMatthias Springer //===- TransformOps.cpp - Transform dialect operations --------------------===// 20eb403adSAlex Zinenko // 30eb403adSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40eb403adSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 50eb403adSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60eb403adSAlex Zinenko // 70eb403adSAlex Zinenko //===----------------------------------------------------------------------===// 80eb403adSAlex Zinenko 90eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.h" 10c63d2b2cSMatthias Springer 110bb4d4d3SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12920c4612SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 130bb4d4d3SMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 143fe7127dSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformAttrs.h" 1530f22429SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 166fe03096SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h" 1791856b34SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" 185a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 1942b16035SQuinn Dawkins #include "mlir/IR/BuiltinAttributes.h" 2063c9d2b1SAlex Zinenko #include "mlir/IR/Diagnostics.h" 212c95ede4SMatthias Springer #include "mlir/IR/Dominance.h" 22e4b04b39SOleksandr "Alex" Zinenko #include "mlir/IR/OpImplementation.h" 23a8cfa7cbSJakub Kuderski #include "mlir/IR/OperationSupport.h" 2430f22429SAlex Zinenko #include "mlir/IR/PatternMatch.h" 257dfcd4b7SMatthias Springer #include "mlir/IR/Verifier.h" 2642b16035SQuinn Dawkins #include "mlir/Interfaces/CallInterfaces.h" 2773c3dff1SAlex Zinenko #include "mlir/Interfaces/ControlFlowInterfaces.h" 2834a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionImplementation.h" 29633d9184SOleksandr "Alex" Zinenko #include "mlir/Interfaces/FunctionInterfaces.h" 3018ec2030SMatthias Springer #include "mlir/Pass/Pass.h" 3118ec2030SMatthias Springer #include "mlir/Pass/PassManager.h" 3218ec2030SMatthias Springer #include "mlir/Pass/PassRegistry.h" 332c95ede4SMatthias Springer #include "mlir/Transforms/CSE.h" 34bcfdb3e4SMatthias Springer #include "mlir/Transforms/DialectConversion.h" 350b52fa90SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 36fa1a23a7SMatthias Springer #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 3742b16035SQuinn Dawkins #include "llvm/ADT/DenseSet.h" 38bba85ebdSAlex Zinenko #include "llvm/ADT/STLExtras.h" 3930f22429SAlex Zinenko #include "llvm/ADT/ScopeExit.h" 4063c9d2b1SAlex Zinenko #include "llvm/ADT/SmallPtrSet.h" 41f90b6090SOleksandr "Alex" Zinenko #include "llvm/ADT/TypeSwitch.h" 42e3890b7fSAlex Zinenko #include "llvm/Support/Debug.h" 4342b16035SQuinn Dawkins #include "llvm/Support/ErrorHandling.h" 4405423905SKazu Hirata #include <optional> 45e3890b7fSAlex Zinenko 46e3890b7fSAlex Zinenko #define DEBUG_TYPE "transform-dialect" 47e3890b7fSAlex Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 480eb403adSAlex Zinenko 4963c9d2b1SAlex Zinenko #define DEBUG_TYPE_MATCHER "transform-matcher" 5063c9d2b1SAlex Zinenko #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") 5163c9d2b1SAlex Zinenko #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) 5263c9d2b1SAlex Zinenko 530eb403adSAlex Zinenko using namespace mlir; 540eb403adSAlex Zinenko 55b9e40cdeSAlex Zinenko static ParseResult parseSequenceOpOperands( 5605423905SKazu Hirata OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 57b9e40cdeSAlex Zinenko Type &rootType, 58b9e40cdeSAlex Zinenko SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 59b9e40cdeSAlex Zinenko SmallVectorImpl<Type> &extraBindingTypes); 60b9e40cdeSAlex Zinenko static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 61b9e40cdeSAlex Zinenko Value root, Type rootType, 62b9e40cdeSAlex Zinenko ValueRange extraBindings, 63b9e40cdeSAlex Zinenko TypeRange extraBindingTypes); 6463c9d2b1SAlex Zinenko static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 6563c9d2b1SAlex Zinenko ArrayAttr matchers, ArrayAttr actions); 6663c9d2b1SAlex Zinenko static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, 6763c9d2b1SAlex Zinenko ArrayAttr &matchers, 6863c9d2b1SAlex Zinenko ArrayAttr &actions); 69b9e40cdeSAlex Zinenko 7018ec2030SMatthias Springer /// Helper function to check if the given transform op is contained in (or 7118ec2030SMatthias Springer /// equal to) the given payload target op. In that case, an error is returned. 7218ec2030SMatthias Springer /// Transforming transform IR that is currently executing is generally unsafe. 7318ec2030SMatthias Springer static DiagnosedSilenceableFailure 7418ec2030SMatthias Springer ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, 7518ec2030SMatthias Springer Operation *payload) { 7618ec2030SMatthias Springer Operation *transformAncestor = transform.getOperation(); 7718ec2030SMatthias Springer while (transformAncestor) { 7818ec2030SMatthias Springer if (transformAncestor == payload) { 7918ec2030SMatthias Springer DiagnosedDefiniteFailure diag = 8018ec2030SMatthias Springer transform.emitDefiniteFailure() 8118ec2030SMatthias Springer << "cannot apply transform to itself (or one of its ancestors)"; 8218ec2030SMatthias Springer diag.attachNote(payload->getLoc()) << "target payload op"; 8318ec2030SMatthias Springer return diag; 8418ec2030SMatthias Springer } 8518ec2030SMatthias Springer transformAncestor = transformAncestor->getParentOp(); 8618ec2030SMatthias Springer } 8718ec2030SMatthias Springer return DiagnosedSilenceableFailure::success(); 8818ec2030SMatthias Springer } 8918ec2030SMatthias Springer 900eb403adSAlex Zinenko #define GET_OP_CLASSES 910eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 920eb403adSAlex Zinenko 9330f22429SAlex Zinenko //===----------------------------------------------------------------------===// 94e3890b7fSAlex Zinenko // AlternativesOp 95e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===// 96e3890b7fSAlex Zinenko 974dd744acSMarkus Böck OperandRange 984dd744acSMarkus Böck transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { 994dd744acSMarkus Böck if (!point.isParent() && getOperation()->getNumOperands() == 1) 100e3890b7fSAlex Zinenko return getOperation()->getOperands(); 101e3890b7fSAlex Zinenko return OperandRange(getOperation()->operand_end(), 102e3890b7fSAlex Zinenko getOperation()->operand_end()); 103e3890b7fSAlex Zinenko } 104e3890b7fSAlex Zinenko 105e3890b7fSAlex Zinenko void transform::AlternativesOp::getSuccessorRegions( 1064dd744acSMarkus Böck RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 107491d2701SKazu Hirata for (Region &alternative : llvm::drop_begin( 1084dd744acSMarkus Böck getAlternatives(), 1094dd744acSMarkus Böck point.isParent() ? 0 1104dd744acSMarkus Böck : point.getRegionOrNull()->getRegionNumber() + 1)) { 111e3890b7fSAlex Zinenko regions.emplace_back(&alternative, !getOperands().empty() 112e3890b7fSAlex Zinenko ? alternative.getArguments() 113e3890b7fSAlex Zinenko : Block::BlockArgListType()); 114e3890b7fSAlex Zinenko } 1154dd744acSMarkus Böck if (!point.isParent()) 116e3890b7fSAlex Zinenko regions.emplace_back(getOperation()->getResults()); 117e3890b7fSAlex Zinenko } 118e3890b7fSAlex Zinenko 119e3890b7fSAlex Zinenko void transform::AlternativesOp::getRegionInvocationBounds( 120e3890b7fSAlex Zinenko ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 121e3890b7fSAlex Zinenko (void)operands; 122e3890b7fSAlex Zinenko // The region corresponding to the first alternative is always executed, the 123e3890b7fSAlex Zinenko // remaining may or may not be executed. 124e3890b7fSAlex Zinenko bounds.reserve(getNumRegions()); 125e3890b7fSAlex Zinenko bounds.emplace_back(1, 1); 126e3890b7fSAlex Zinenko bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 127e3890b7fSAlex Zinenko } 128e3890b7fSAlex Zinenko 129aa6a6c56SNicolas Vasilache static void forwardEmptyOperands(Block *block, transform::TransformState &state, 130aa6a6c56SNicolas Vasilache transform::TransformResults &results) { 131aa6a6c56SNicolas Vasilache for (const auto &res : block->getParentOp()->getOpResults()) 132aa6a6c56SNicolas Vasilache results.set(res, {}); 133aa6a6c56SNicolas Vasilache } 134aa6a6c56SNicolas Vasilache 1351d45282aSAlex Zinenko DiagnosedSilenceableFailure 136c63d2b2cSMatthias Springer transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, 137c63d2b2cSMatthias Springer transform::TransformResults &results, 138e3890b7fSAlex Zinenko transform::TransformState &state) { 139e3890b7fSAlex Zinenko SmallVector<Operation *> originals; 140e3890b7fSAlex Zinenko if (Value scopeHandle = getScope()) 141e3890b7fSAlex Zinenko llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 142e3890b7fSAlex Zinenko else 143e3890b7fSAlex Zinenko originals.push_back(state.getTopLevel()); 144e3890b7fSAlex Zinenko 145e3890b7fSAlex Zinenko for (Operation *original : originals) { 146e3890b7fSAlex Zinenko if (original->isAncestor(getOperation())) { 147b0bf7fffSAlex Zinenko auto diag = emitDefiniteFailure() 148b0bf7fffSAlex Zinenko << "scope must not contain the transforms being applied"; 149e3890b7fSAlex Zinenko diag.attachNote(original->getLoc()) << "scope"; 150b0bf7fffSAlex Zinenko return diag; 1511d45282aSAlex Zinenko } 1521d45282aSAlex Zinenko if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 153b0bf7fffSAlex Zinenko auto diag = emitDefiniteFailure() 1541d45282aSAlex Zinenko << "only isolated-from-above ops can be alternative scopes"; 1551d45282aSAlex Zinenko diag.attachNote(original->getLoc()) << "scope"; 156b0bf7fffSAlex Zinenko return diag; 157e3890b7fSAlex Zinenko } 158e3890b7fSAlex Zinenko } 159e3890b7fSAlex Zinenko 160e3890b7fSAlex Zinenko for (Region ® : getAlternatives()) { 161e3890b7fSAlex Zinenko // Clone the scope operations and make the transforms in this alternative 162e3890b7fSAlex Zinenko // region apply to them by virtue of mapping the block argument (the only 163e3890b7fSAlex Zinenko // visible handle) to the cloned scope operations. This effectively prevents 164e3890b7fSAlex Zinenko // the transformation from accessing any IR outside the scope. 165e3890b7fSAlex Zinenko auto scope = state.make_region_scope(reg); 166e3890b7fSAlex Zinenko auto clones = llvm::to_vector( 167e3890b7fSAlex Zinenko llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 168e3890b7fSAlex Zinenko auto deleteClones = llvm::make_scope_exit([&] { 169e3890b7fSAlex Zinenko for (Operation *clone : clones) 170e3890b7fSAlex Zinenko clone->erase(); 171e3890b7fSAlex Zinenko }); 172bba85ebdSAlex Zinenko if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 173bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 174e3890b7fSAlex Zinenko 175e3890b7fSAlex Zinenko bool failed = false; 176e3890b7fSAlex Zinenko for (Operation &transform : reg.front().without_terminator()) { 1771d45282aSAlex Zinenko DiagnosedSilenceableFailure result = 178e3890b7fSAlex Zinenko state.applyTransform(cast<TransformOpInterface>(transform)); 1791d45282aSAlex Zinenko if (result.isSilenceableFailure()) { 180e3890b7fSAlex Zinenko LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 181e3890b7fSAlex Zinenko << "\n"); 182e3890b7fSAlex Zinenko failed = true; 183e3890b7fSAlex Zinenko break; 184e3890b7fSAlex Zinenko } 185e3890b7fSAlex Zinenko 186e3890b7fSAlex Zinenko if (::mlir::failed(result.silence())) 1871d45282aSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 188e3890b7fSAlex Zinenko } 189e3890b7fSAlex Zinenko 190e3890b7fSAlex Zinenko // If all operations in the given alternative succeeded, no need to consider 191e3890b7fSAlex Zinenko // the rest. Replace the original scoping operation with the clone on which 192e3890b7fSAlex Zinenko // the transformations were performed. 193e3890b7fSAlex Zinenko if (!failed) { 194e3890b7fSAlex Zinenko // We will be using the clones, so cancel their scheduled deletion. 195e3890b7fSAlex Zinenko deleteClones.release(); 196905e9324SMatthias Springer TrackingListener listener(state, *this); 19707fef178SMatthias Springer IRRewriter rewriter(getContext(), &listener); 198e3890b7fSAlex Zinenko for (const auto &kvp : llvm::zip(originals, clones)) { 199e3890b7fSAlex Zinenko Operation *original = std::get<0>(kvp); 200e3890b7fSAlex Zinenko Operation *clone = std::get<1>(kvp); 201e3890b7fSAlex Zinenko original->getBlock()->getOperations().insert(original->getIterator(), 202e3890b7fSAlex Zinenko clone); 203e3890b7fSAlex Zinenko rewriter.replaceOp(original, clone->getResults()); 204e3890b7fSAlex Zinenko } 20563c9d2b1SAlex Zinenko detail::forwardTerminatorOperands(®.front(), state, results); 2061d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 207e3890b7fSAlex Zinenko } 208e3890b7fSAlex Zinenko } 2091d45282aSAlex Zinenko return emitSilenceableError() << "all alternatives failed"; 210e3890b7fSAlex Zinenko } 211e3890b7fSAlex Zinenko 212d46afeefSAlex Zinenko void transform::AlternativesOp::getEffects( 213d46afeefSAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2142c1ae801Sdonald chen consumesHandle(getOperation()->getOpOperands(), effects); 2152c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 216d46afeefSAlex Zinenko for (Region *region : getRegions()) { 217d46afeefSAlex Zinenko if (!region->empty()) 218d46afeefSAlex Zinenko producesHandle(region->front().getArguments(), effects); 219d46afeefSAlex Zinenko } 220d46afeefSAlex Zinenko modifiesPayload(effects); 221d46afeefSAlex Zinenko } 222d46afeefSAlex Zinenko 223e3890b7fSAlex Zinenko LogicalResult transform::AlternativesOp::verify() { 224e3890b7fSAlex Zinenko for (Region &alternative : getAlternatives()) { 225e3890b7fSAlex Zinenko Block &block = alternative.front(); 226e3890b7fSAlex Zinenko Operation *terminator = block.getTerminator(); 227e3890b7fSAlex Zinenko if (terminator->getOperands().getTypes() != getResults().getTypes()) { 228e3890b7fSAlex Zinenko InFlightDiagnostic diag = emitOpError() 229e3890b7fSAlex Zinenko << "expects terminator operands to have the " 230e3890b7fSAlex Zinenko "same type as results of the operation"; 231e3890b7fSAlex Zinenko diag.attachNote(terminator->getLoc()) << "terminator"; 232e3890b7fSAlex Zinenko return diag; 233e3890b7fSAlex Zinenko } 234e3890b7fSAlex Zinenko } 235e3890b7fSAlex Zinenko 236e3890b7fSAlex Zinenko return success(); 237e3890b7fSAlex Zinenko } 238e3890b7fSAlex Zinenko 239e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===// 2405e7ac250SQuinn Dawkins // AnnotateOp 2415e7ac250SQuinn Dawkins //===----------------------------------------------------------------------===// 2425e7ac250SQuinn Dawkins 2435e7ac250SQuinn Dawkins DiagnosedSilenceableFailure 244c63d2b2cSMatthias Springer transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, 245c63d2b2cSMatthias Springer transform::TransformResults &results, 2465e7ac250SQuinn Dawkins transform::TransformState &state) { 2475e7ac250SQuinn Dawkins SmallVector<Operation *> targets = 2485e7ac250SQuinn Dawkins llvm::to_vector(state.getPayloadOps(getTarget())); 2495e7ac250SQuinn Dawkins 2505e7ac250SQuinn Dawkins Attribute attr = UnitAttr::get(getContext()); 2515e7ac250SQuinn Dawkins if (auto paramH = getParam()) { 2525e7ac250SQuinn Dawkins ArrayRef<Attribute> params = state.getParams(paramH); 2535e7ac250SQuinn Dawkins if (params.size() != 1) { 2545e7ac250SQuinn Dawkins if (targets.size() != params.size()) { 2555e7ac250SQuinn Dawkins return emitSilenceableError() 2565e7ac250SQuinn Dawkins << "parameter and target have different payload lengths (" 2575e7ac250SQuinn Dawkins << params.size() << " vs " << targets.size() << ")"; 2585e7ac250SQuinn Dawkins } 2595e7ac250SQuinn Dawkins for (auto &&[target, attr] : llvm::zip_equal(targets, params)) 2605e7ac250SQuinn Dawkins target->setAttr(getName(), attr); 2615e7ac250SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 2625e7ac250SQuinn Dawkins } 2635e7ac250SQuinn Dawkins attr = params[0]; 2645e7ac250SQuinn Dawkins } 265153661dbSMehdi Amini for (auto *target : targets) 2665e7ac250SQuinn Dawkins target->setAttr(getName(), attr); 2675e7ac250SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 2685e7ac250SQuinn Dawkins } 2695e7ac250SQuinn Dawkins 2705e7ac250SQuinn Dawkins void transform::AnnotateOp::getEffects( 2715e7ac250SQuinn Dawkins SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2722c1ae801Sdonald chen onlyReadsHandle(getTargetMutable(), effects); 2732c1ae801Sdonald chen onlyReadsHandle(getParamMutable(), effects); 2745e7ac250SQuinn Dawkins modifiesPayload(effects); 2755e7ac250SQuinn Dawkins } 2765e7ac250SQuinn Dawkins 2775e7ac250SQuinn Dawkins //===----------------------------------------------------------------------===// 2782c95ede4SMatthias Springer // ApplyCommonSubexpressionEliminationOp 2792c95ede4SMatthias Springer //===----------------------------------------------------------------------===// 2802c95ede4SMatthias Springer 2812c95ede4SMatthias Springer DiagnosedSilenceableFailure 2822c95ede4SMatthias Springer transform::ApplyCommonSubexpressionEliminationOp::applyToOne( 2832c95ede4SMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 2842c95ede4SMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 2852c95ede4SMatthias Springer // Make sure that this transform is not applied to itself. Modifying the 2862c95ede4SMatthias Springer // transform IR while it is being interpreted is generally dangerous. 2872c95ede4SMatthias Springer DiagnosedSilenceableFailure payloadCheck = 2882c95ede4SMatthias Springer ensurePayloadIsSeparateFromTransform(*this, target); 2892c95ede4SMatthias Springer if (!payloadCheck.succeeded()) 2902c95ede4SMatthias Springer return payloadCheck; 2912c95ede4SMatthias Springer 2922c95ede4SMatthias Springer DominanceInfo domInfo; 2932c95ede4SMatthias Springer mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); 2942c95ede4SMatthias Springer return DiagnosedSilenceableFailure::success(); 2952c95ede4SMatthias Springer } 2962c95ede4SMatthias Springer 2972c95ede4SMatthias Springer void transform::ApplyCommonSubexpressionEliminationOp::getEffects( 2982c95ede4SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2992c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 3002c95ede4SMatthias Springer transform::modifiesPayload(effects); 3012c95ede4SMatthias Springer } 3022c95ede4SMatthias Springer 3032c95ede4SMatthias Springer //===----------------------------------------------------------------------===// 304c2d5d348SMatthias Springer // ApplyDeadCodeEliminationOp 305c2d5d348SMatthias Springer //===----------------------------------------------------------------------===// 306c2d5d348SMatthias Springer 307c2d5d348SMatthias Springer DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( 308c2d5d348SMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 309c2d5d348SMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 310c2d5d348SMatthias Springer // Make sure that this transform is not applied to itself. Modifying the 311c2d5d348SMatthias Springer // transform IR while it is being interpreted is generally dangerous. 312c2d5d348SMatthias Springer DiagnosedSilenceableFailure payloadCheck = 313c2d5d348SMatthias Springer ensurePayloadIsSeparateFromTransform(*this, target); 314c2d5d348SMatthias Springer if (!payloadCheck.succeeded()) 315c2d5d348SMatthias Springer return payloadCheck; 316c2d5d348SMatthias Springer 317c2d5d348SMatthias Springer // Maintain a worklist of potentially dead ops. 318c2d5d348SMatthias Springer SetVector<Operation *> worklist; 319c2d5d348SMatthias Springer 320c2d5d348SMatthias Springer // Helper function that adds all defining ops of used values (operands and 321c2d5d348SMatthias Springer // operands of nested ops). 322c2d5d348SMatthias Springer auto addDefiningOpsToWorklist = [&](Operation *op) { 323c2d5d348SMatthias Springer op->walk([&](Operation *op) { 324c2d5d348SMatthias Springer for (Value v : op->getOperands()) 325c2d5d348SMatthias Springer if (Operation *defOp = v.getDefiningOp()) 326c2d5d348SMatthias Springer if (target->isProperAncestor(defOp)) 327c2d5d348SMatthias Springer worklist.insert(defOp); 328c2d5d348SMatthias Springer }); 329c2d5d348SMatthias Springer }; 330c2d5d348SMatthias Springer 331c2d5d348SMatthias Springer // Helper function that erases an op. 332c2d5d348SMatthias Springer auto eraseOp = [&](Operation *op) { 333c2d5d348SMatthias Springer // Remove op and nested ops from the worklist. 334c2d5d348SMatthias Springer op->walk([&](Operation *op) { 335153661dbSMehdi Amini const auto *it = llvm::find(worklist, op); 336c2d5d348SMatthias Springer if (it != worklist.end()) 337c2d5d348SMatthias Springer worklist.erase(it); 338c2d5d348SMatthias Springer }); 339c2d5d348SMatthias Springer rewriter.eraseOp(op); 340c2d5d348SMatthias Springer }; 341c2d5d348SMatthias Springer 342c2d5d348SMatthias Springer // Initial walk over the IR. 343c2d5d348SMatthias Springer target->walk<WalkOrder::PostOrder>([&](Operation *op) { 344c2d5d348SMatthias Springer if (op != target && isOpTriviallyDead(op)) { 345c2d5d348SMatthias Springer addDefiningOpsToWorklist(op); 346c2d5d348SMatthias Springer eraseOp(op); 347c2d5d348SMatthias Springer } 348c2d5d348SMatthias Springer }); 349c2d5d348SMatthias Springer 350c2d5d348SMatthias Springer // Erase all ops that have become dead. 351c2d5d348SMatthias Springer while (!worklist.empty()) { 352c2d5d348SMatthias Springer Operation *op = worklist.pop_back_val(); 353c2d5d348SMatthias Springer if (!isOpTriviallyDead(op)) 354c2d5d348SMatthias Springer continue; 355c2d5d348SMatthias Springer addDefiningOpsToWorklist(op); 356c2d5d348SMatthias Springer eraseOp(op); 357c2d5d348SMatthias Springer } 358c2d5d348SMatthias Springer 359c2d5d348SMatthias Springer return DiagnosedSilenceableFailure::success(); 360c2d5d348SMatthias Springer } 361c2d5d348SMatthias Springer 362c2d5d348SMatthias Springer void transform::ApplyDeadCodeEliminationOp::getEffects( 363c2d5d348SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3642c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 365c2d5d348SMatthias Springer transform::modifiesPayload(effects); 366c2d5d348SMatthias Springer } 367c2d5d348SMatthias Springer 368c2d5d348SMatthias Springer //===----------------------------------------------------------------------===// 3690b52fa90SMatthias Springer // ApplyPatternsOp 3700b52fa90SMatthias Springer //===----------------------------------------------------------------------===// 3710b52fa90SMatthias Springer 372c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( 373c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 374c63d2b2cSMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 375726d0767SMatthias Springer // Make sure that this transform is not applied to itself. Modifying the 376726d0767SMatthias Springer // transform IR while it is being interpreted is generally dangerous. Even 377726d0767SMatthias Springer // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver 378726d0767SMatthias Springer // performs many additional simplifications such as dead code elimination. 37918ec2030SMatthias Springer DiagnosedSilenceableFailure payloadCheck = 38018ec2030SMatthias Springer ensurePayloadIsSeparateFromTransform(*this, target); 38118ec2030SMatthias Springer if (!payloadCheck.succeeded()) 38218ec2030SMatthias Springer return payloadCheck; 383726d0767SMatthias Springer 3840b52fa90SMatthias Springer // Gather all specified patterns. 3850b52fa90SMatthias Springer MLIRContext *ctx = target->getContext(); 3860b52fa90SMatthias Springer RewritePatternSet patterns(ctx); 3875a10f207SMatthias Springer if (!getRegion().empty()) { 3885a10f207SMatthias Springer for (Operation &op : getRegion().front()) { 389e55e36deSOleksandr "Alex" Zinenko cast<transform::PatternDescriptorOpInterface>(&op) 390e55e36deSOleksandr "Alex" Zinenko .populatePatternsWithState(patterns, state); 3915a10f207SMatthias Springer } 3925a10f207SMatthias Springer } 3930b52fa90SMatthias Springer 3940b52fa90SMatthias Springer // Configure the GreedyPatternRewriteDriver. 3950b52fa90SMatthias Springer GreedyRewriteConfig config; 396c63d2b2cSMatthias Springer config.listener = 397c63d2b2cSMatthias Springer static_cast<RewriterBase::Listener *>(rewriter.getListener()); 39820245ed4SMatthias Springer FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 3990b52fa90SMatthias Springer 40037b26bf4SOleksandr "Alex" Zinenko config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1) 40137b26bf4SOleksandr "Alex" Zinenko ? GreedyRewriteConfig::kNoLimit 40237b26bf4SOleksandr "Alex" Zinenko : getMaxIterations(); 40337b26bf4SOleksandr "Alex" Zinenko config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1) 40437b26bf4SOleksandr "Alex" Zinenko ? GreedyRewriteConfig::kNoLimit 40537b26bf4SOleksandr "Alex" Zinenko : getMaxNumRewrites(); 40637b26bf4SOleksandr "Alex" Zinenko 40720245ed4SMatthias Springer // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE 40820245ed4SMatthias Springer // was requested, apply the greedy pattern rewrite only once. (The greedy 40920245ed4SMatthias Springer // pattern rewrite driver already iterates to a fixpoint internally.) 41020245ed4SMatthias Springer bool cseChanged = false; 41120245ed4SMatthias Springer // One or two iterations should be sufficient. Stop iterating after a certain 41220245ed4SMatthias Springer // threshold to make debugging easier. 41320245ed4SMatthias Springer static const int64_t kNumMaxIterations = 50; 41420245ed4SMatthias Springer int64_t iteration = 0; 41520245ed4SMatthias Springer do { 416976d25edSMatthias Springer LogicalResult result = failure(); 417976d25edSMatthias Springer if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 418976d25edSMatthias Springer // Op is isolated from above. Apply patterns and also perform region 419976d25edSMatthias Springer // simplification. 42009dfc571SJacques Pienaar result = applyPatternsGreedily(target, frozenPatterns, config); 421976d25edSMatthias Springer } else { 42220245ed4SMatthias Springer // Manually gather list of ops because the other 42320245ed4SMatthias Springer // GreedyPatternRewriteDriver overloads only accepts ops that are isolated 42420245ed4SMatthias Springer // from above. This way, patterns can be applied to ops that are not 42520245ed4SMatthias Springer // isolated from above. Regions are not being simplified. Furthermore, 42620245ed4SMatthias Springer // only a single greedy rewrite iteration is performed. 4270b52fa90SMatthias Springer SmallVector<Operation *> ops; 4280b52fa90SMatthias Springer target->walk([&](Operation *nestedOp) { 4290b52fa90SMatthias Springer if (target != nestedOp) 4300b52fa90SMatthias Springer ops.push_back(nestedOp); 4310b52fa90SMatthias Springer }); 43209dfc571SJacques Pienaar result = applyOpPatternsGreedily(ops, frozenPatterns, config); 433976d25edSMatthias Springer } 434976d25edSMatthias Springer 4350b52fa90SMatthias Springer // A failure typically indicates that the pattern application did not 4360b52fa90SMatthias Springer // converge. 4370b52fa90SMatthias Springer if (failed(result)) { 4380b52fa90SMatthias Springer return emitSilenceableFailure(target) 4390b52fa90SMatthias Springer << "greedy pattern application failed"; 4400b52fa90SMatthias Springer } 4410b52fa90SMatthias Springer 44220245ed4SMatthias Springer if (getApplyCse()) { 44320245ed4SMatthias Springer DominanceInfo domInfo; 44420245ed4SMatthias Springer mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, 44520245ed4SMatthias Springer &cseChanged); 44620245ed4SMatthias Springer } 44720245ed4SMatthias Springer } while (cseChanged && ++iteration < kNumMaxIterations); 44820245ed4SMatthias Springer 44920245ed4SMatthias Springer if (iteration == kNumMaxIterations) 45020245ed4SMatthias Springer return emitDefiniteFailure() << "fixpoint iteration did not converge"; 45120245ed4SMatthias Springer 4520b52fa90SMatthias Springer return DiagnosedSilenceableFailure::success(); 4530b52fa90SMatthias Springer } 4540b52fa90SMatthias Springer 4550b52fa90SMatthias Springer LogicalResult transform::ApplyPatternsOp::verify() { 4565a10f207SMatthias Springer if (!getRegion().empty()) { 4575a10f207SMatthias Springer for (Operation &op : getRegion().front()) { 4585a10f207SMatthias Springer if (!isa<transform::PatternDescriptorOpInterface>(&op)) { 4595a10f207SMatthias Springer InFlightDiagnostic diag = emitOpError() 4605a10f207SMatthias Springer << "expected children ops to implement " 4615a10f207SMatthias Springer "PatternDescriptorOpInterface"; 4625a10f207SMatthias Springer diag.attachNote(op.getLoc()) << "op without interface"; 4635a10f207SMatthias Springer return diag; 4645a10f207SMatthias Springer } 4655a10f207SMatthias Springer } 4665a10f207SMatthias Springer } 4670b52fa90SMatthias Springer return success(); 4680b52fa90SMatthias Springer } 4690b52fa90SMatthias Springer 4700b52fa90SMatthias Springer void transform::ApplyPatternsOp::getEffects( 4710b52fa90SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 4722c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 4730b52fa90SMatthias Springer transform::modifiesPayload(effects); 4740b52fa90SMatthias Springer } 4750b52fa90SMatthias Springer 476223a0f63SMatthias Springer void transform::ApplyPatternsOp::build( 477223a0f63SMatthias Springer OpBuilder &builder, OperationState &result, Value target, 478c63d2b2cSMatthias Springer function_ref<void(OpBuilder &, Location)> bodyBuilder) { 479223a0f63SMatthias Springer result.addOperands(target); 480223a0f63SMatthias Springer 481223a0f63SMatthias Springer OpBuilder::InsertionGuard g(builder); 482223a0f63SMatthias Springer Region *region = result.addRegion(); 483223a0f63SMatthias Springer builder.createBlock(region); 484223a0f63SMatthias Springer if (bodyBuilder) 485223a0f63SMatthias Springer bodyBuilder(builder, result.location); 486223a0f63SMatthias Springer } 487223a0f63SMatthias Springer 4880b52fa90SMatthias Springer //===----------------------------------------------------------------------===// 4895a10f207SMatthias Springer // ApplyCanonicalizationPatternsOp 4905a10f207SMatthias Springer //===----------------------------------------------------------------------===// 4915a10f207SMatthias Springer 4925a10f207SMatthias Springer void transform::ApplyCanonicalizationPatternsOp::populatePatterns( 4935a10f207SMatthias Springer RewritePatternSet &patterns) { 4945a10f207SMatthias Springer MLIRContext *ctx = patterns.getContext(); 4955a10f207SMatthias Springer for (Dialect *dialect : ctx->getLoadedDialects()) 4965a10f207SMatthias Springer dialect->getCanonicalizationPatterns(patterns); 4975a10f207SMatthias Springer for (RegisteredOperationName op : ctx->getRegisteredOperations()) 4985a10f207SMatthias Springer op.getCanonicalizationPatterns(patterns, ctx); 4995a10f207SMatthias Springer } 5005a10f207SMatthias Springer 5015a10f207SMatthias Springer //===----------------------------------------------------------------------===// 502bcfdb3e4SMatthias Springer // ApplyConversionPatternsOp 503bcfdb3e4SMatthias Springer //===----------------------------------------------------------------------===// 504bcfdb3e4SMatthias Springer 505bcfdb3e4SMatthias Springer DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( 506bcfdb3e4SMatthias Springer transform::TransformRewriter &rewriter, 507bcfdb3e4SMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 508bcfdb3e4SMatthias Springer MLIRContext *ctx = getContext(); 509bcfdb3e4SMatthias Springer 51099475f5bSNicolas Vasilache // Instantiate the default type converter if a type converter builder is 51199475f5bSNicolas Vasilache // specified. 512bcfdb3e4SMatthias Springer std::unique_ptr<TypeConverter> defaultTypeConverter; 51399475f5bSNicolas Vasilache transform::TypeConverterBuilderOpInterface typeConverterBuilder = 51499475f5bSNicolas Vasilache getDefaultTypeConverter(); 51599475f5bSNicolas Vasilache if (typeConverterBuilder) 51699475f5bSNicolas Vasilache defaultTypeConverter = typeConverterBuilder.getTypeConverter(); 517bcfdb3e4SMatthias Springer 518bcfdb3e4SMatthias Springer // Configure conversion target. 519920c4612SNicolas Vasilache ConversionTarget conversionTarget(*getContext()); 520bcfdb3e4SMatthias Springer if (getLegalOps()) 521bcfdb3e4SMatthias Springer for (Attribute attr : cast<ArrayAttr>(*getLegalOps())) 522bcfdb3e4SMatthias Springer conversionTarget.addLegalOp( 523bcfdb3e4SMatthias Springer OperationName(cast<StringAttr>(attr).getValue(), ctx)); 524bcfdb3e4SMatthias Springer if (getIllegalOps()) 525bcfdb3e4SMatthias Springer for (Attribute attr : cast<ArrayAttr>(*getIllegalOps())) 526bcfdb3e4SMatthias Springer conversionTarget.addIllegalOp( 527bcfdb3e4SMatthias Springer OperationName(cast<StringAttr>(attr).getValue(), ctx)); 528bcfdb3e4SMatthias Springer if (getLegalDialects()) 529bcfdb3e4SMatthias Springer for (Attribute attr : cast<ArrayAttr>(*getLegalDialects())) 530bcfdb3e4SMatthias Springer conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue()); 531bcfdb3e4SMatthias Springer if (getIllegalDialects()) 532bcfdb3e4SMatthias Springer for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects())) 533bcfdb3e4SMatthias Springer conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue()); 534bcfdb3e4SMatthias Springer 535bcfdb3e4SMatthias Springer // Gather all specified patterns. 536bcfdb3e4SMatthias Springer RewritePatternSet patterns(ctx); 53799475f5bSNicolas Vasilache // Need to keep the converters alive until after pattern application because 53899475f5bSNicolas Vasilache // the patterns take a reference to an object that would otherwise get out of 53999475f5bSNicolas Vasilache // scope. 54099475f5bSNicolas Vasilache SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters; 541bcfdb3e4SMatthias Springer if (!getPatterns().empty()) { 542bcfdb3e4SMatthias Springer for (Operation &op : getPatterns().front()) { 543bcfdb3e4SMatthias Springer auto descriptor = 544bcfdb3e4SMatthias Springer cast<transform::ConversionPatternDescriptorOpInterface>(&op); 545bcfdb3e4SMatthias Springer 546bcfdb3e4SMatthias Springer // Check if this pattern set specifies a type converter. 547bcfdb3e4SMatthias Springer std::unique_ptr<TypeConverter> typeConverter = 548bcfdb3e4SMatthias Springer descriptor.getTypeConverter(); 549bcfdb3e4SMatthias Springer TypeConverter *converter = nullptr; 550bcfdb3e4SMatthias Springer if (typeConverter) { 55199475f5bSNicolas Vasilache keepAliveConverters.emplace_back(std::move(typeConverter)); 55299475f5bSNicolas Vasilache converter = keepAliveConverters.back().get(); 553bcfdb3e4SMatthias Springer } else { 554bcfdb3e4SMatthias Springer // No type converter specified: Use the default type converter. 555bcfdb3e4SMatthias Springer if (!defaultTypeConverter) { 556bcfdb3e4SMatthias Springer auto diag = emitDefiniteFailure() 557bcfdb3e4SMatthias Springer << "pattern descriptor does not specify type " 558bcfdb3e4SMatthias Springer "converter and apply_conversion_patterns op has " 559bcfdb3e4SMatthias Springer "no default type converter"; 560bcfdb3e4SMatthias Springer diag.attachNote(op.getLoc()) << "pattern descriptor op"; 561bcfdb3e4SMatthias Springer return diag; 562bcfdb3e4SMatthias Springer } 563bcfdb3e4SMatthias Springer converter = defaultTypeConverter.get(); 564bcfdb3e4SMatthias Springer } 565e2d39f79SChristopher Bate 566e2d39f79SChristopher Bate // Add descriptor-specific updates to the conversion target, which may 567e2d39f79SChristopher Bate // depend on the final type converter. In structural converters, the 568e2d39f79SChristopher Bate // legality of types dictates the dynamic legality of an operation. 569e2d39f79SChristopher Bate descriptor.populateConversionTargetRules(*converter, conversionTarget); 570e2d39f79SChristopher Bate 571bcfdb3e4SMatthias Springer descriptor.populatePatterns(*converter, patterns); 572bcfdb3e4SMatthias Springer } 573bcfdb3e4SMatthias Springer } 574bcfdb3e4SMatthias Springer 575c1029b6aSMatthias Springer // Attach a tracking listener if handles should be preserved. We configure the 576c1029b6aSMatthias Springer // listener to allow op replacements with different names, as conversion 577c1029b6aSMatthias Springer // patterns typically replace ops with replacement ops that have a different 578c1029b6aSMatthias Springer // name. 579c1029b6aSMatthias Springer TrackingListenerConfig trackingConfig; 580c1029b6aSMatthias Springer trackingConfig.requireMatchingReplacementOpName = false; 581c1029b6aSMatthias Springer ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); 582c1029b6aSMatthias Springer ConversionConfig conversionConfig; 583c1029b6aSMatthias Springer if (getPreserveHandles()) 584c1029b6aSMatthias Springer conversionConfig.listener = &trackingListener; 585c1029b6aSMatthias Springer 58699475f5bSNicolas Vasilache FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 587bcfdb3e4SMatthias Springer for (Operation *target : state.getPayloadOps(getTarget())) { 588bcfdb3e4SMatthias Springer // Make sure that this transform is not applied to itself. Modifying the 589bcfdb3e4SMatthias Springer // transform IR while it is being interpreted is generally dangerous. 590bcfdb3e4SMatthias Springer DiagnosedSilenceableFailure payloadCheck = 591bcfdb3e4SMatthias Springer ensurePayloadIsSeparateFromTransform(*this, target); 592bcfdb3e4SMatthias Springer if (!payloadCheck.succeeded()) 593bcfdb3e4SMatthias Springer return payloadCheck; 594bcfdb3e4SMatthias Springer 595bcfdb3e4SMatthias Springer LogicalResult status = failure(); 596bcfdb3e4SMatthias Springer if (getPartialConversion()) { 597c1029b6aSMatthias Springer status = applyPartialConversion(target, conversionTarget, frozenPatterns, 598c1029b6aSMatthias Springer conversionConfig); 599bcfdb3e4SMatthias Springer } else { 600c1029b6aSMatthias Springer status = applyFullConversion(target, conversionTarget, frozenPatterns, 601c1029b6aSMatthias Springer conversionConfig); 602bcfdb3e4SMatthias Springer } 603bcfdb3e4SMatthias Springer 604c1029b6aSMatthias Springer // Check dialect conversion state. 605c1029b6aSMatthias Springer DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); 606bcfdb3e4SMatthias Springer if (failed(status)) { 607c1029b6aSMatthias Springer diag = emitSilenceableError() << "dialect conversion failed"; 608bcfdb3e4SMatthias Springer diag.attachNote(target->getLoc()) << "target op"; 609bcfdb3e4SMatthias Springer } 610c1029b6aSMatthias Springer 611c1029b6aSMatthias Springer // Check tracking listener error state. 612c1029b6aSMatthias Springer DiagnosedSilenceableFailure trackingFailure = 613c1029b6aSMatthias Springer trackingListener.checkAndResetError(); 614c1029b6aSMatthias Springer if (!trackingFailure.succeeded()) { 615c1029b6aSMatthias Springer if (diag.succeeded()) { 616c1029b6aSMatthias Springer // Tracking failure is the only failure. 617c1029b6aSMatthias Springer return trackingFailure; 618c1029b6aSMatthias Springer } else { 619c1029b6aSMatthias Springer diag.attachNote() << "tracking listener also failed: " 620c1029b6aSMatthias Springer << trackingFailure.getMessage(); 621c1029b6aSMatthias Springer (void)trackingFailure.silence(); 622c1029b6aSMatthias Springer } 623c1029b6aSMatthias Springer } 624c1029b6aSMatthias Springer 625c1029b6aSMatthias Springer if (!diag.succeeded()) 626c1029b6aSMatthias Springer return diag; 627bcfdb3e4SMatthias Springer } 628bcfdb3e4SMatthias Springer 629bcfdb3e4SMatthias Springer return DiagnosedSilenceableFailure::success(); 630bcfdb3e4SMatthias Springer } 631bcfdb3e4SMatthias Springer 632bcfdb3e4SMatthias Springer LogicalResult transform::ApplyConversionPatternsOp::verify() { 633bcfdb3e4SMatthias Springer if (getNumRegions() != 1 && getNumRegions() != 2) 634bcfdb3e4SMatthias Springer return emitOpError() << "expected 1 or 2 regions"; 635bcfdb3e4SMatthias Springer if (!getPatterns().empty()) { 636bcfdb3e4SMatthias Springer for (Operation &op : getPatterns().front()) { 637bcfdb3e4SMatthias Springer if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) { 638bcfdb3e4SMatthias Springer InFlightDiagnostic diag = 639bcfdb3e4SMatthias Springer emitOpError() << "expected pattern children ops to implement " 640bcfdb3e4SMatthias Springer "ConversionPatternDescriptorOpInterface"; 641bcfdb3e4SMatthias Springer diag.attachNote(op.getLoc()) << "op without interface"; 642bcfdb3e4SMatthias Springer return diag; 643bcfdb3e4SMatthias Springer } 644bcfdb3e4SMatthias Springer } 645bcfdb3e4SMatthias Springer } 646bcfdb3e4SMatthias Springer if (getNumRegions() == 2) { 647bcfdb3e4SMatthias Springer Region &typeConverterRegion = getRegion(1); 648bcfdb3e4SMatthias Springer if (!llvm::hasSingleElement(typeConverterRegion.front())) 649bcfdb3e4SMatthias Springer return emitOpError() 650bcfdb3e4SMatthias Springer << "expected exactly one op in default type converter region"; 6514527adc5SDaniel Kuts Operation *maybeTypeConverter = &typeConverterRegion.front().front(); 6527ec88f06SMatthias Springer auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>( 6534527adc5SDaniel Kuts maybeTypeConverter); 6547ec88f06SMatthias Springer if (!typeConverterOp) { 655bcfdb3e4SMatthias Springer InFlightDiagnostic diag = emitOpError() 656bcfdb3e4SMatthias Springer << "expected default converter child op to " 657bcfdb3e4SMatthias Springer "implement TypeConverterBuilderOpInterface"; 6584527adc5SDaniel Kuts diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface"; 659bcfdb3e4SMatthias Springer return diag; 660bcfdb3e4SMatthias Springer } 6617ec88f06SMatthias Springer // Check default type converter type. 6627ec88f06SMatthias Springer if (!getPatterns().empty()) { 6637ec88f06SMatthias Springer for (Operation &op : getPatterns().front()) { 6647ec88f06SMatthias Springer auto descriptor = 6657ec88f06SMatthias Springer cast<transform::ConversionPatternDescriptorOpInterface>(&op); 6667ec88f06SMatthias Springer if (failed(descriptor.verifyTypeConverter(typeConverterOp))) 6677ec88f06SMatthias Springer return failure(); 6687ec88f06SMatthias Springer } 6697ec88f06SMatthias Springer } 670bcfdb3e4SMatthias Springer } 671bcfdb3e4SMatthias Springer return success(); 672bcfdb3e4SMatthias Springer } 673bcfdb3e4SMatthias Springer 674bcfdb3e4SMatthias Springer void transform::ApplyConversionPatternsOp::getEffects( 675bcfdb3e4SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 676c1029b6aSMatthias Springer if (!getPreserveHandles()) { 6772c1ae801Sdonald chen transform::consumesHandle(getTargetMutable(), effects); 678c1029b6aSMatthias Springer } else { 6792c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 680c1029b6aSMatthias Springer } 681bcfdb3e4SMatthias Springer transform::modifiesPayload(effects); 682bcfdb3e4SMatthias Springer } 683bcfdb3e4SMatthias Springer 684bcfdb3e4SMatthias Springer void transform::ApplyConversionPatternsOp::build( 685bcfdb3e4SMatthias Springer OpBuilder &builder, OperationState &result, Value target, 686bcfdb3e4SMatthias Springer function_ref<void(OpBuilder &, Location)> patternsBodyBuilder, 687bcfdb3e4SMatthias Springer function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) { 688bcfdb3e4SMatthias Springer result.addOperands(target); 689bcfdb3e4SMatthias Springer 690bcfdb3e4SMatthias Springer { 691bcfdb3e4SMatthias Springer OpBuilder::InsertionGuard g(builder); 692bcfdb3e4SMatthias Springer Region *region1 = result.addRegion(); 693bcfdb3e4SMatthias Springer builder.createBlock(region1); 694bcfdb3e4SMatthias Springer if (patternsBodyBuilder) 695bcfdb3e4SMatthias Springer patternsBodyBuilder(builder, result.location); 696bcfdb3e4SMatthias Springer } 697bcfdb3e4SMatthias Springer { 698bcfdb3e4SMatthias Springer OpBuilder::InsertionGuard g(builder); 699bcfdb3e4SMatthias Springer Region *region2 = result.addRegion(); 700bcfdb3e4SMatthias Springer builder.createBlock(region2); 701bcfdb3e4SMatthias Springer if (typeConverterBodyBuilder) 702bcfdb3e4SMatthias Springer typeConverterBodyBuilder(builder, result.location); 703bcfdb3e4SMatthias Springer } 704bcfdb3e4SMatthias Springer } 705bcfdb3e4SMatthias Springer 706bcfdb3e4SMatthias Springer //===----------------------------------------------------------------------===// 7070bb4d4d3SMatthias Springer // ApplyToLLVMConversionPatternsOp 7080bb4d4d3SMatthias Springer //===----------------------------------------------------------------------===// 7090bb4d4d3SMatthias Springer 7100bb4d4d3SMatthias Springer void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( 7110bb4d4d3SMatthias Springer TypeConverter &typeConverter, RewritePatternSet &patterns) { 7120bb4d4d3SMatthias Springer Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 7130bb4d4d3SMatthias Springer assert(dialect && "expected that dialect is loaded"); 714153661dbSMehdi Amini auto *iface = cast<ConvertToLLVMPatternInterface>(dialect); 7150bb4d4d3SMatthias Springer // ConversionTarget is currently ignored because the enclosing 7160bb4d4d3SMatthias Springer // apply_conversion_patterns op sets up its own ConversionTarget. 7170bb4d4d3SMatthias Springer ConversionTarget target(*getContext()); 7180bb4d4d3SMatthias Springer iface->populateConvertToLLVMConversionPatterns( 7190bb4d4d3SMatthias Springer target, static_cast<LLVMTypeConverter &>(typeConverter), patterns); 7200bb4d4d3SMatthias Springer } 7210bb4d4d3SMatthias Springer 7220bb4d4d3SMatthias Springer LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( 7230bb4d4d3SMatthias Springer transform::TypeConverterBuilderOpInterface builder) { 7240bb4d4d3SMatthias Springer if (builder.getTypeConverterType() != "LLVMTypeConverter") 7250bb4d4d3SMatthias Springer return emitOpError("expected LLVMTypeConverter"); 7260bb4d4d3SMatthias Springer return success(); 7270bb4d4d3SMatthias Springer } 7280bb4d4d3SMatthias Springer 7290bb4d4d3SMatthias Springer LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { 7300bb4d4d3SMatthias Springer Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); 7310bb4d4d3SMatthias Springer if (!dialect) 7320bb4d4d3SMatthias Springer return emitOpError("unknown dialect or dialect not loaded: ") 7330bb4d4d3SMatthias Springer << getDialectName(); 734153661dbSMehdi Amini auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 7350bb4d4d3SMatthias Springer if (!iface) 7360bb4d4d3SMatthias Springer return emitOpError( 7370bb4d4d3SMatthias Springer "dialect does not implement ConvertToLLVMPatternInterface or " 7380bb4d4d3SMatthias Springer "extension was not loaded: ") 7390bb4d4d3SMatthias Springer << getDialectName(); 7400bb4d4d3SMatthias Springer return success(); 7410bb4d4d3SMatthias Springer } 7420bb4d4d3SMatthias Springer 7430bb4d4d3SMatthias Springer //===----------------------------------------------------------------------===// 744fa1a23a7SMatthias Springer // ApplyLoopInvariantCodeMotionOp 745fa1a23a7SMatthias Springer //===----------------------------------------------------------------------===// 746fa1a23a7SMatthias Springer 747fa1a23a7SMatthias Springer DiagnosedSilenceableFailure 748fa1a23a7SMatthias Springer transform::ApplyLoopInvariantCodeMotionOp::applyToOne( 749fa1a23a7SMatthias Springer transform::TransformRewriter &rewriter, LoopLikeOpInterface target, 750fa1a23a7SMatthias Springer transform::ApplyToEachResultList &results, 751fa1a23a7SMatthias Springer transform::TransformState &state) { 752fa1a23a7SMatthias Springer // Currently, LICM does not remove operations, so we don't need tracking. 753fa1a23a7SMatthias Springer // If this ever changes, add a LICM entry point that takes a rewriter. 754fa1a23a7SMatthias Springer moveLoopInvariantCode(target); 755fa1a23a7SMatthias Springer return DiagnosedSilenceableFailure::success(); 756fa1a23a7SMatthias Springer } 757fa1a23a7SMatthias Springer 758fa1a23a7SMatthias Springer void transform::ApplyLoopInvariantCodeMotionOp::getEffects( 759fa1a23a7SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 7602c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 761fa1a23a7SMatthias Springer transform::modifiesPayload(effects); 762fa1a23a7SMatthias Springer } 763fa1a23a7SMatthias Springer 764fa1a23a7SMatthias Springer //===----------------------------------------------------------------------===// 76518ec2030SMatthias Springer // ApplyRegisteredPassOp 76618ec2030SMatthias Springer //===----------------------------------------------------------------------===// 76718ec2030SMatthias Springer 768c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( 769c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 770c63d2b2cSMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 77118ec2030SMatthias Springer // Make sure that this transform is not applied to itself. Modifying the 77218ec2030SMatthias Springer // transform IR while it is being interpreted is generally dangerous. Even 77318ec2030SMatthias Springer // more so when applying passes because they may perform a wide range of IR 77418ec2030SMatthias Springer // modifications. 77518ec2030SMatthias Springer DiagnosedSilenceableFailure payloadCheck = 77618ec2030SMatthias Springer ensurePayloadIsSeparateFromTransform(*this, target); 77718ec2030SMatthias Springer if (!payloadCheck.succeeded()) 77818ec2030SMatthias Springer return payloadCheck; 77918ec2030SMatthias Springer 7802f8690b1SMatthias Springer // Get pass or pass pipeline from registry. 7812f8690b1SMatthias Springer const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); 7822f8690b1SMatthias Springer if (!info) 7832f8690b1SMatthias Springer info = PassInfo::lookup(getPassName()); 7842f8690b1SMatthias Springer if (!info) 7852f8690b1SMatthias Springer return emitDefiniteFailure() 7862f8690b1SMatthias Springer << "unknown pass or pass pipeline: " << getPassName(); 78718ec2030SMatthias Springer 7882f8690b1SMatthias Springer // Create pass manager and run the pass or pass pipeline. 78918ec2030SMatthias Springer PassManager pm(getContext()); 7902f8690b1SMatthias Springer if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { 79118ec2030SMatthias Springer emitError(msg); 79218ec2030SMatthias Springer return failure(); 79318ec2030SMatthias Springer }))) { 79418ec2030SMatthias Springer return emitDefiniteFailure() 7952f8690b1SMatthias Springer << "failed to add pass or pass pipeline to pipeline: " 7962f8690b1SMatthias Springer << getPassName(); 79718ec2030SMatthias Springer } 79818ec2030SMatthias Springer if (failed(pm.run(target))) { 79918ec2030SMatthias Springer auto diag = emitSilenceableError() << "pass pipeline failed"; 80018ec2030SMatthias Springer diag.attachNote(target->getLoc()) << "target op"; 80118ec2030SMatthias Springer return diag; 80218ec2030SMatthias Springer } 80318ec2030SMatthias Springer 80418ec2030SMatthias Springer results.push_back(target); 80518ec2030SMatthias Springer return DiagnosedSilenceableFailure::success(); 80618ec2030SMatthias Springer } 80718ec2030SMatthias Springer 80818ec2030SMatthias Springer //===----------------------------------------------------------------------===// 8092e5fe721SLorenzo Chelini // CastOp 810bffec215SMatthias Springer //===----------------------------------------------------------------------===// 811bffec215SMatthias Springer 812bffec215SMatthias Springer DiagnosedSilenceableFailure 813c63d2b2cSMatthias Springer transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, 814c63d2b2cSMatthias Springer Operation *target, ApplyToEachResultList &results, 815bba85ebdSAlex Zinenko transform::TransformState &state) { 816bba85ebdSAlex Zinenko results.push_back(target); 817bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::success(); 818bba85ebdSAlex Zinenko } 819bba85ebdSAlex Zinenko 820bba85ebdSAlex Zinenko void transform::CastOp::getEffects( 821bba85ebdSAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 822bba85ebdSAlex Zinenko onlyReadsPayload(effects); 8232c1ae801Sdonald chen onlyReadsHandle(getInputMutable(), effects); 8242c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 825bba85ebdSAlex Zinenko } 826bba85ebdSAlex Zinenko 827bba85ebdSAlex Zinenko bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 828bba85ebdSAlex Zinenko assert(inputs.size() == 1 && "expected one input"); 829bba85ebdSAlex Zinenko assert(outputs.size() == 1 && "expected one output"); 830bba85ebdSAlex Zinenko return llvm::all_of( 831bba85ebdSAlex Zinenko std::initializer_list<Type>{inputs.front(), outputs.front()}, 832971b8525SJakub Kuderski llvm::IsaPred<transform::TransformHandleTypeInterface>); 833bba85ebdSAlex Zinenko } 834bba85ebdSAlex Zinenko 835bba85ebdSAlex Zinenko //===----------------------------------------------------------------------===// 836633d9184SOleksandr "Alex" Zinenko // CollectMatchingOp 83763c9d2b1SAlex Zinenko //===----------------------------------------------------------------------===// 83863c9d2b1SAlex Zinenko 839e4b04b39SOleksandr "Alex" Zinenko /// Applies matcher operations from the given `block` using 840e4b04b39SOleksandr "Alex" Zinenko /// `blockArgumentMapping` to initialize block arguments. Updates `state` 841e4b04b39SOleksandr "Alex" Zinenko /// accordingly. If any of the matcher produces a silenceable failure, discards 842e4b04b39SOleksandr "Alex" Zinenko /// it (printing the content to the debug output stream) and returns failure. If 843e4b04b39SOleksandr "Alex" Zinenko /// any of the matchers produces a definite failure, reports it and returns 844e4b04b39SOleksandr "Alex" Zinenko /// failure. If all matchers in the block succeed, populates `mappings` with the 845e4b04b39SOleksandr "Alex" Zinenko /// payload entities associated with the block terminator operands. Note that 846e4b04b39SOleksandr "Alex" Zinenko /// `mappings` will be cleared before that. 84763c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure 848e4b04b39SOleksandr "Alex" Zinenko matchBlock(Block &block, 849e4b04b39SOleksandr "Alex" Zinenko ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping, 850e4b04b39SOleksandr "Alex" Zinenko transform::TransformState &state, 85163c9d2b1SAlex Zinenko SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { 85263c9d2b1SAlex Zinenko assert(block.getParent() && "cannot match using a detached block"); 85322259281SMatthias Springer auto matchScope = state.make_region_scope(*block.getParent()); 854e4b04b39SOleksandr "Alex" Zinenko if (failed( 855e4b04b39SOleksandr "Alex" Zinenko state.mapBlockArguments(block.getArguments(), blockArgumentMapping))) 85663c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 85763c9d2b1SAlex Zinenko 85863c9d2b1SAlex Zinenko for (Operation &match : block.without_terminator()) { 85963c9d2b1SAlex Zinenko if (!isa<transform::MatchOpInterface>(match)) { 86063c9d2b1SAlex Zinenko return emitDefiniteFailure(match.getLoc()) 86163c9d2b1SAlex Zinenko << "expected operations in the match part to " 86263c9d2b1SAlex Zinenko "implement MatchOpInterface"; 86363c9d2b1SAlex Zinenko } 86463c9d2b1SAlex Zinenko DiagnosedSilenceableFailure diag = 86563c9d2b1SAlex Zinenko state.applyTransform(cast<transform::TransformOpInterface>(match)); 86663c9d2b1SAlex Zinenko if (diag.succeeded()) 86763c9d2b1SAlex Zinenko continue; 86863c9d2b1SAlex Zinenko 86963c9d2b1SAlex Zinenko return diag; 87063c9d2b1SAlex Zinenko } 87163c9d2b1SAlex Zinenko 87263c9d2b1SAlex Zinenko // Remember the values mapped to the terminator operands so we can 87363c9d2b1SAlex Zinenko // forward them to the action. 87463c9d2b1SAlex Zinenko ValueRange yieldedValues = block.getTerminator()->getOperands(); 875e4b04b39SOleksandr "Alex" Zinenko // Our contract with the caller is that the mappings will contain only the 876e4b04b39SOleksandr "Alex" Zinenko // newly mapped values, clear the rest. 877e4b04b39SOleksandr "Alex" Zinenko mappings.clear(); 87863c9d2b1SAlex Zinenko transform::detail::prepareValueMappings(mappings, yieldedValues, state); 87963c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::success(); 88063c9d2b1SAlex Zinenko } 88163c9d2b1SAlex Zinenko 882633d9184SOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the interfaces provided as 883633d9184SOleksandr "Alex" Zinenko /// template parameters. 884633d9184SOleksandr "Alex" Zinenko template <typename... Tys> 885633d9184SOleksandr "Alex" Zinenko static bool implementSameInterface(Type t1, Type t2) { 886633d9184SOleksandr "Alex" Zinenko return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); 887633d9184SOleksandr "Alex" Zinenko } 888633d9184SOleksandr "Alex" Zinenko 889633d9184SOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the transform dialect 890633d9184SOleksandr "Alex" Zinenko /// interfaces. 891633d9184SOleksandr "Alex" Zinenko static bool implementSameTransformInterface(Type t1, Type t2) { 892633d9184SOleksandr "Alex" Zinenko return implementSameInterface<transform::TransformHandleTypeInterface, 893633d9184SOleksandr "Alex" Zinenko transform::TransformParamTypeInterface, 894633d9184SOleksandr "Alex" Zinenko transform::TransformValueHandleTypeInterface>( 895633d9184SOleksandr "Alex" Zinenko t1, t2); 896633d9184SOleksandr "Alex" Zinenko } 897633d9184SOleksandr "Alex" Zinenko 898633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 899633d9184SOleksandr "Alex" Zinenko // CollectMatchingOp 900633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 901633d9184SOleksandr "Alex" Zinenko 902633d9184SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure 903633d9184SOleksandr "Alex" Zinenko transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, 904633d9184SOleksandr "Alex" Zinenko transform::TransformResults &results, 905633d9184SOleksandr "Alex" Zinenko transform::TransformState &state) { 906633d9184SOleksandr "Alex" Zinenko auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( 907633d9184SOleksandr "Alex" Zinenko getOperation(), getMatcher()); 908633d9184SOleksandr "Alex" Zinenko if (matcher.isExternal()) { 909633d9184SOleksandr "Alex" Zinenko return emitDefiniteFailure() 910633d9184SOleksandr "Alex" Zinenko << "unresolved external symbol " << getMatcher(); 911633d9184SOleksandr "Alex" Zinenko } 912633d9184SOleksandr "Alex" Zinenko 913633d9184SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>, 2> rawResults; 914633d9184SOleksandr "Alex" Zinenko rawResults.resize(getOperation()->getNumResults()); 915633d9184SOleksandr "Alex" Zinenko std::optional<DiagnosedSilenceableFailure> maybeFailure; 916633d9184SOleksandr "Alex" Zinenko for (Operation *root : state.getPayloadOps(getRoot())) { 917633d9184SOleksandr "Alex" Zinenko WalkResult walkResult = root->walk([&](Operation *op) { 918633d9184SOleksandr "Alex" Zinenko DEBUG_MATCHER({ 919633d9184SOleksandr "Alex" Zinenko DBGS_MATCHER() << "matching "; 920633d9184SOleksandr "Alex" Zinenko op->print(llvm::dbgs(), 921633d9184SOleksandr "Alex" Zinenko OpPrintingFlags().assumeVerified().skipRegions()); 922633d9184SOleksandr "Alex" Zinenko llvm::dbgs() << " @" << op << "\n"; 923633d9184SOleksandr "Alex" Zinenko }); 924633d9184SOleksandr "Alex" Zinenko 925633d9184SOleksandr "Alex" Zinenko // Try matching. 926633d9184SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> mappings; 927e4b04b39SOleksandr "Alex" Zinenko SmallVector<transform::MappedValue> inputMapping({op}); 928e4b04b39SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = matchBlock( 929e4b04b39SOleksandr "Alex" Zinenko matcher.getFunctionBody().front(), 930e4b04b39SOleksandr "Alex" Zinenko ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state, 931e4b04b39SOleksandr "Alex" Zinenko mappings); 932633d9184SOleksandr "Alex" Zinenko if (diag.isDefiniteFailure()) 933633d9184SOleksandr "Alex" Zinenko return WalkResult::interrupt(); 934633d9184SOleksandr "Alex" Zinenko if (diag.isSilenceableFailure()) { 935633d9184SOleksandr "Alex" Zinenko DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() 936633d9184SOleksandr "Alex" Zinenko << " failed: " << diag.getMessage()); 937633d9184SOleksandr "Alex" Zinenko return WalkResult::advance(); 938633d9184SOleksandr "Alex" Zinenko } 939633d9184SOleksandr "Alex" Zinenko 940633d9184SOleksandr "Alex" Zinenko // If succeeded, collect results. 941633d9184SOleksandr "Alex" Zinenko for (auto &&[i, mapping] : llvm::enumerate(mappings)) { 942633d9184SOleksandr "Alex" Zinenko if (mapping.size() != 1) { 943633d9184SOleksandr "Alex" Zinenko maybeFailure.emplace(emitSilenceableError() 944633d9184SOleksandr "Alex" Zinenko << "result #" << i << ", associated with " 945633d9184SOleksandr "Alex" Zinenko << mapping.size() 946633d9184SOleksandr "Alex" Zinenko << " payload objects, expected 1"); 947633d9184SOleksandr "Alex" Zinenko return WalkResult::interrupt(); 948633d9184SOleksandr "Alex" Zinenko } 949633d9184SOleksandr "Alex" Zinenko rawResults[i].push_back(mapping[0]); 950633d9184SOleksandr "Alex" Zinenko } 951633d9184SOleksandr "Alex" Zinenko return WalkResult::advance(); 952633d9184SOleksandr "Alex" Zinenko }); 953633d9184SOleksandr "Alex" Zinenko if (walkResult.wasInterrupted()) 954633d9184SOleksandr "Alex" Zinenko return std::move(*maybeFailure); 955633d9184SOleksandr "Alex" Zinenko assert(!maybeFailure && "failure set but the walk was not interrupted"); 956633d9184SOleksandr "Alex" Zinenko 957633d9184SOleksandr "Alex" Zinenko for (auto &&[opResult, rawResult] : 958633d9184SOleksandr "Alex" Zinenko llvm::zip_equal(getOperation()->getResults(), rawResults)) { 959633d9184SOleksandr "Alex" Zinenko results.setMappedValues(opResult, rawResult); 960633d9184SOleksandr "Alex" Zinenko } 961633d9184SOleksandr "Alex" Zinenko } 962633d9184SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 963633d9184SOleksandr "Alex" Zinenko } 964633d9184SOleksandr "Alex" Zinenko 965633d9184SOleksandr "Alex" Zinenko void transform::CollectMatchingOp::getEffects( 966633d9184SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 9672c1ae801Sdonald chen onlyReadsHandle(getRootMutable(), effects); 9682c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 969633d9184SOleksandr "Alex" Zinenko onlyReadsPayload(effects); 970633d9184SOleksandr "Alex" Zinenko } 971633d9184SOleksandr "Alex" Zinenko 972633d9184SOleksandr "Alex" Zinenko LogicalResult transform::CollectMatchingOp::verifySymbolUses( 973633d9184SOleksandr "Alex" Zinenko SymbolTableCollection &symbolTable) { 974633d9184SOleksandr "Alex" Zinenko auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 975633d9184SOleksandr "Alex" Zinenko symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); 976633d9184SOleksandr "Alex" Zinenko if (!matcherSymbol || 977633d9184SOleksandr "Alex" Zinenko !isa<TransformOpInterface>(matcherSymbol.getOperation())) 978633d9184SOleksandr "Alex" Zinenko return emitError() << "unresolved matcher symbol " << getMatcher(); 979633d9184SOleksandr "Alex" Zinenko 980633d9184SOleksandr "Alex" Zinenko ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); 981633d9184SOleksandr "Alex" Zinenko if (argumentTypes.size() != 1 || 982633d9184SOleksandr "Alex" Zinenko !isa<TransformHandleTypeInterface>(argumentTypes[0])) { 983633d9184SOleksandr "Alex" Zinenko return emitError() 984633d9184SOleksandr "Alex" Zinenko << "expected the matcher to take one operation handle argument"; 985633d9184SOleksandr "Alex" Zinenko } 986633d9184SOleksandr "Alex" Zinenko if (!matcherSymbol.getArgAttr( 987633d9184SOleksandr "Alex" Zinenko 0, transform::TransformDialect::kArgReadOnlyAttrName)) { 988633d9184SOleksandr "Alex" Zinenko return emitError() << "expected the matcher argument to be marked readonly"; 989633d9184SOleksandr "Alex" Zinenko } 990633d9184SOleksandr "Alex" Zinenko 991633d9184SOleksandr "Alex" Zinenko ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); 992633d9184SOleksandr "Alex" Zinenko if (resultTypes.size() != getOperation()->getNumResults()) { 993633d9184SOleksandr "Alex" Zinenko return emitError() 994633d9184SOleksandr "Alex" Zinenko << "expected the matcher to yield as many values as op has results (" 995633d9184SOleksandr "Alex" Zinenko << getOperation()->getNumResults() << "), got " 996633d9184SOleksandr "Alex" Zinenko << resultTypes.size(); 997633d9184SOleksandr "Alex" Zinenko } 998633d9184SOleksandr "Alex" Zinenko 999633d9184SOleksandr "Alex" Zinenko for (auto &&[i, matcherType, resultType] : 1000633d9184SOleksandr "Alex" Zinenko llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { 1001633d9184SOleksandr "Alex" Zinenko if (implementSameTransformInterface(matcherType, resultType)) 1002633d9184SOleksandr "Alex" Zinenko continue; 1003633d9184SOleksandr "Alex" Zinenko 1004633d9184SOleksandr "Alex" Zinenko return emitError() 1005633d9184SOleksandr "Alex" Zinenko << "mismatching type interfaces for matcher result and op result #" 1006633d9184SOleksandr "Alex" Zinenko << i; 1007633d9184SOleksandr "Alex" Zinenko } 1008633d9184SOleksandr "Alex" Zinenko 1009633d9184SOleksandr "Alex" Zinenko return success(); 1010633d9184SOleksandr "Alex" Zinenko } 1011633d9184SOleksandr "Alex" Zinenko 1012633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 1013633d9184SOleksandr "Alex" Zinenko // ForeachMatchOp 1014633d9184SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 1015633d9184SOleksandr "Alex" Zinenko 1016e4b04b39SOleksandr "Alex" Zinenko // This is fine because nothing is actually consumed by this op. 1017e4b04b39SOleksandr "Alex" Zinenko bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; } 1018e4b04b39SOleksandr "Alex" Zinenko 101963c9d2b1SAlex Zinenko DiagnosedSilenceableFailure 1020c63d2b2cSMatthias Springer transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, 1021c63d2b2cSMatthias Springer transform::TransformResults &results, 102263c9d2b1SAlex Zinenko transform::TransformState &state) { 102363c9d2b1SAlex Zinenko SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>> 102463c9d2b1SAlex Zinenko matchActionPairs; 102563c9d2b1SAlex Zinenko matchActionPairs.reserve(getMatchers().size()); 102663c9d2b1SAlex Zinenko SymbolTableCollection symbolTable; 102763c9d2b1SAlex Zinenko for (auto &&[matcher, action] : 102863c9d2b1SAlex Zinenko llvm::zip_equal(getMatchers(), getActions())) { 102963c9d2b1SAlex Zinenko auto matcherSymbol = 103063c9d2b1SAlex Zinenko symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( 103163c9d2b1SAlex Zinenko getOperation(), cast<SymbolRefAttr>(matcher)); 103263c9d2b1SAlex Zinenko auto actionSymbol = 103363c9d2b1SAlex Zinenko symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( 103463c9d2b1SAlex Zinenko getOperation(), cast<SymbolRefAttr>(action)); 103563c9d2b1SAlex Zinenko assert(matcherSymbol && actionSymbol && 103663c9d2b1SAlex Zinenko "unresolved symbols not caught by the verifier"); 103763c9d2b1SAlex Zinenko 103863c9d2b1SAlex Zinenko if (matcherSymbol.isExternal()) 103963c9d2b1SAlex Zinenko return emitDefiniteFailure() << "unresolved external symbol " << matcher; 104063c9d2b1SAlex Zinenko if (actionSymbol.isExternal()) 104163c9d2b1SAlex Zinenko return emitDefiniteFailure() << "unresolved external symbol " << action; 104263c9d2b1SAlex Zinenko 104363c9d2b1SAlex Zinenko matchActionPairs.emplace_back(matcherSymbol, actionSymbol); 104463c9d2b1SAlex Zinenko } 104563c9d2b1SAlex Zinenko 10460b790572SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure overallDiag = 10470b790572SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure::success(); 1048e4b04b39SOleksandr "Alex" Zinenko 1049e4b04b39SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> matchInputMapping; 1050e4b04b39SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> matchOutputMapping; 1051e4b04b39SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> actionResultMapping; 1052e4b04b39SOleksandr "Alex" Zinenko // Explicitly add the mapping for the first block argument (the op being 1053e4b04b39SOleksandr "Alex" Zinenko // matched). 1054e4b04b39SOleksandr "Alex" Zinenko matchInputMapping.emplace_back(); 1055e4b04b39SOleksandr "Alex" Zinenko transform::detail::prepareValueMappings(matchInputMapping, 1056e4b04b39SOleksandr "Alex" Zinenko getForwardedInputs(), state); 1057e4b04b39SOleksandr "Alex" Zinenko SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front(); 1058e4b04b39SOleksandr "Alex" Zinenko actionResultMapping.resize(getForwardedOutputs().size()); 1059e4b04b39SOleksandr "Alex" Zinenko 106063c9d2b1SAlex Zinenko for (Operation *root : state.getPayloadOps(getRoot())) { 106163c9d2b1SAlex Zinenko WalkResult walkResult = root->walk([&](Operation *op) { 10628483d18bSNicolas Vasilache // If getRestrictRoot is not present, skip over the root op itself so we 10638483d18bSNicolas Vasilache // don't invalidate it. 10648483d18bSNicolas Vasilache if (!getRestrictRoot() && op == root) 106563c9d2b1SAlex Zinenko return WalkResult::advance(); 106663c9d2b1SAlex Zinenko 106763c9d2b1SAlex Zinenko DEBUG_MATCHER({ 106863c9d2b1SAlex Zinenko DBGS_MATCHER() << "matching "; 106963c9d2b1SAlex Zinenko op->print(llvm::dbgs(), 107063c9d2b1SAlex Zinenko OpPrintingFlags().assumeVerified().skipRegions()); 107163c9d2b1SAlex Zinenko llvm::dbgs() << " @" << op << "\n"; 107263c9d2b1SAlex Zinenko }); 107363c9d2b1SAlex Zinenko 1074e4b04b39SOleksandr "Alex" Zinenko firstMatchArgument.clear(); 1075e4b04b39SOleksandr "Alex" Zinenko firstMatchArgument.push_back(op); 1076e4b04b39SOleksandr "Alex" Zinenko 107763c9d2b1SAlex Zinenko // Try all the match/action pairs until the first successful match. 107863c9d2b1SAlex Zinenko for (auto [matcher, action] : matchActionPairs) { 107963c9d2b1SAlex Zinenko DiagnosedSilenceableFailure diag = 1080e4b04b39SOleksandr "Alex" Zinenko matchBlock(matcher.getFunctionBody().front(), matchInputMapping, 1081e4b04b39SOleksandr "Alex" Zinenko state, matchOutputMapping); 108263c9d2b1SAlex Zinenko if (diag.isDefiniteFailure()) 108363c9d2b1SAlex Zinenko return WalkResult::interrupt(); 108463c9d2b1SAlex Zinenko if (diag.isSilenceableFailure()) { 10853fe7127dSAlex Zinenko DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() 10863fe7127dSAlex Zinenko << " failed: " << diag.getMessage()); 108763c9d2b1SAlex Zinenko continue; 108863c9d2b1SAlex Zinenko } 108963c9d2b1SAlex Zinenko 109022259281SMatthias Springer auto scope = state.make_region_scope(action.getFunctionBody()); 1091e4b04b39SOleksandr "Alex" Zinenko if (failed(state.mapBlockArguments( 1092e4b04b39SOleksandr "Alex" Zinenko action.getFunctionBody().front().getArguments(), 1093e4b04b39SOleksandr "Alex" Zinenko matchOutputMapping))) { 109463c9d2b1SAlex Zinenko return WalkResult::interrupt(); 109563c9d2b1SAlex Zinenko } 109663c9d2b1SAlex Zinenko 109763c9d2b1SAlex Zinenko for (Operation &transform : 109863c9d2b1SAlex Zinenko action.getFunctionBody().front().without_terminator()) { 109963c9d2b1SAlex Zinenko DiagnosedSilenceableFailure result = 110063c9d2b1SAlex Zinenko state.applyTransform(cast<TransformOpInterface>(transform)); 11010b790572SOleksandr "Alex" Zinenko if (result.isDefiniteFailure()) 110263c9d2b1SAlex Zinenko return WalkResult::interrupt(); 11030b790572SOleksandr "Alex" Zinenko if (result.isSilenceableFailure()) { 11040b790572SOleksandr "Alex" Zinenko if (overallDiag.succeeded()) { 11050b790572SOleksandr "Alex" Zinenko overallDiag = emitSilenceableError() << "actions failed"; 11060b790572SOleksandr "Alex" Zinenko } 11070b790572SOleksandr "Alex" Zinenko overallDiag.attachNote(action->getLoc()) 11080b790572SOleksandr "Alex" Zinenko << "failed action: " << result.getMessage(); 11090b790572SOleksandr "Alex" Zinenko overallDiag.attachNote(op->getLoc()) 11100b790572SOleksandr "Alex" Zinenko << "when applied to this matching payload"; 11110b790572SOleksandr "Alex" Zinenko (void)result.silence(); 11120b790572SOleksandr "Alex" Zinenko continue; 11130b790572SOleksandr "Alex" Zinenko } 111463c9d2b1SAlex Zinenko } 1115e4b04b39SOleksandr "Alex" Zinenko if (failed(detail::appendValueMappings( 1116e4b04b39SOleksandr "Alex" Zinenko MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping), 1117e4b04b39SOleksandr "Alex" Zinenko action.getFunctionBody().front().getTerminator()->getOperands(), 1118e4b04b39SOleksandr "Alex" Zinenko state, getFlattenResults()))) { 1119e4b04b39SOleksandr "Alex" Zinenko emitDefiniteFailure() 1120e4b04b39SOleksandr "Alex" Zinenko << "action @" << action.getName() 1121e4b04b39SOleksandr "Alex" Zinenko << " has results associated with multiple payload entities, " 1122e4b04b39SOleksandr "Alex" Zinenko "but flattening was not requested"; 1123e4b04b39SOleksandr "Alex" Zinenko return WalkResult::interrupt(); 1124e4b04b39SOleksandr "Alex" Zinenko } 112563c9d2b1SAlex Zinenko break; 112663c9d2b1SAlex Zinenko } 112763c9d2b1SAlex Zinenko return WalkResult::advance(); 112863c9d2b1SAlex Zinenko }); 112963c9d2b1SAlex Zinenko if (walkResult.wasInterrupted()) 113063c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 113163c9d2b1SAlex Zinenko } 113263c9d2b1SAlex Zinenko 113363c9d2b1SAlex Zinenko // The root operation should not have been affected, so we can just reassign 113463c9d2b1SAlex Zinenko // the payload to the result. Note that we need to consume the root handle to 113563c9d2b1SAlex Zinenko // make sure any handles to operations inside, that could have been affected 113663c9d2b1SAlex Zinenko // by actions, are invalidated. 1137c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getUpdated()), 1138c1fa60b4STres Popp state.getPayloadOps(getRoot())); 1139e4b04b39SOleksandr "Alex" Zinenko for (auto &&[result, mapping] : 1140e4b04b39SOleksandr "Alex" Zinenko llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) { 1141e4b04b39SOleksandr "Alex" Zinenko results.setMappedValues(result, mapping); 1142e4b04b39SOleksandr "Alex" Zinenko } 11430b790572SOleksandr "Alex" Zinenko return overallDiag; 114463c9d2b1SAlex Zinenko } 114563c9d2b1SAlex Zinenko 1146e4b04b39SOleksandr "Alex" Zinenko void transform::ForeachMatchOp::getAsmResultNames( 1147e4b04b39SOleksandr "Alex" Zinenko OpAsmSetValueNameFn setNameFn) { 1148e4b04b39SOleksandr "Alex" Zinenko setNameFn(getUpdated(), "updated_root"); 1149e4b04b39SOleksandr "Alex" Zinenko for (Value v : getForwardedOutputs()) { 1150e4b04b39SOleksandr "Alex" Zinenko setNameFn(v, "yielded"); 1151e4b04b39SOleksandr "Alex" Zinenko } 1152e4b04b39SOleksandr "Alex" Zinenko } 1153e4b04b39SOleksandr "Alex" Zinenko 115463c9d2b1SAlex Zinenko void transform::ForeachMatchOp::getEffects( 115563c9d2b1SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 115663c9d2b1SAlex Zinenko // Bail if invalid. 115763c9d2b1SAlex Zinenko if (getOperation()->getNumOperands() < 1 || 115863c9d2b1SAlex Zinenko getOperation()->getNumResults() < 1) { 115963c9d2b1SAlex Zinenko return modifiesPayload(effects); 116063c9d2b1SAlex Zinenko } 116163c9d2b1SAlex Zinenko 11622c1ae801Sdonald chen consumesHandle(getRootMutable(), effects); 11632c1ae801Sdonald chen onlyReadsHandle(getForwardedInputsMutable(), effects); 11642c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 116563c9d2b1SAlex Zinenko modifiesPayload(effects); 116663c9d2b1SAlex Zinenko } 116763c9d2b1SAlex Zinenko 116863c9d2b1SAlex Zinenko /// Parses the comma-separated list of symbol reference pairs of the format 116963c9d2b1SAlex Zinenko /// `@matcher -> @action`. 117063c9d2b1SAlex Zinenko static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, 117163c9d2b1SAlex Zinenko ArrayAttr &matchers, 117263c9d2b1SAlex Zinenko ArrayAttr &actions) { 117363c9d2b1SAlex Zinenko StringAttr matcher; 117463c9d2b1SAlex Zinenko StringAttr action; 117563c9d2b1SAlex Zinenko SmallVector<Attribute> matcherList; 117663c9d2b1SAlex Zinenko SmallVector<Attribute> actionList; 117763c9d2b1SAlex Zinenko do { 117863c9d2b1SAlex Zinenko if (parser.parseSymbolName(matcher) || parser.parseArrow() || 117963c9d2b1SAlex Zinenko parser.parseSymbolName(action)) { 118063c9d2b1SAlex Zinenko return failure(); 118163c9d2b1SAlex Zinenko } 118263c9d2b1SAlex Zinenko matcherList.push_back(SymbolRefAttr::get(matcher)); 118363c9d2b1SAlex Zinenko actionList.push_back(SymbolRefAttr::get(action)); 118463c9d2b1SAlex Zinenko } while (parser.parseOptionalComma().succeeded()); 118563c9d2b1SAlex Zinenko 118663c9d2b1SAlex Zinenko matchers = parser.getBuilder().getArrayAttr(matcherList); 118763c9d2b1SAlex Zinenko actions = parser.getBuilder().getArrayAttr(actionList); 118863c9d2b1SAlex Zinenko return success(); 118963c9d2b1SAlex Zinenko } 119063c9d2b1SAlex Zinenko 119163c9d2b1SAlex Zinenko /// Prints the comma-separated list of symbol reference pairs of the format 119263c9d2b1SAlex Zinenko /// `@matcher -> @action`. 119363c9d2b1SAlex Zinenko static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, 119463c9d2b1SAlex Zinenko ArrayAttr matchers, ArrayAttr actions) { 119563c9d2b1SAlex Zinenko printer.increaseIndent(); 119663c9d2b1SAlex Zinenko printer.increaseIndent(); 119763c9d2b1SAlex Zinenko for (auto &&[matcher, action, idx] : llvm::zip_equal( 119863c9d2b1SAlex Zinenko matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) { 119963c9d2b1SAlex Zinenko printer.printNewline(); 120063c9d2b1SAlex Zinenko printer << cast<SymbolRefAttr>(matcher) << " -> " 120163c9d2b1SAlex Zinenko << cast<SymbolRefAttr>(action); 120263c9d2b1SAlex Zinenko if (idx != matchers.size() - 1) 120363c9d2b1SAlex Zinenko printer << ", "; 120463c9d2b1SAlex Zinenko } 120563c9d2b1SAlex Zinenko printer.decreaseIndent(); 120663c9d2b1SAlex Zinenko printer.decreaseIndent(); 120763c9d2b1SAlex Zinenko } 120863c9d2b1SAlex Zinenko 120963c9d2b1SAlex Zinenko LogicalResult transform::ForeachMatchOp::verify() { 121063c9d2b1SAlex Zinenko if (getMatchers().size() != getActions().size()) 121163c9d2b1SAlex Zinenko return emitOpError() << "expected the same number of matchers and actions"; 121263c9d2b1SAlex Zinenko if (getMatchers().empty()) 121363c9d2b1SAlex Zinenko return emitOpError() << "expected at least one match/action pair"; 121463c9d2b1SAlex Zinenko 121563c9d2b1SAlex Zinenko llvm::SmallPtrSet<Attribute, 8> matcherNames; 121663c9d2b1SAlex Zinenko for (Attribute name : getMatchers()) { 121763c9d2b1SAlex Zinenko if (matcherNames.insert(name).second) 121863c9d2b1SAlex Zinenko continue; 121963c9d2b1SAlex Zinenko emitWarning() << "matcher " << name 122063c9d2b1SAlex Zinenko << " is used more than once, only the first match will apply"; 122163c9d2b1SAlex Zinenko } 122263c9d2b1SAlex Zinenko 122363c9d2b1SAlex Zinenko return success(); 122463c9d2b1SAlex Zinenko } 122563c9d2b1SAlex Zinenko 122663c9d2b1SAlex Zinenko /// Checks that the attributes of the function-like operation have correct 122763c9d2b1SAlex Zinenko /// consumption effect annotations. If `alsoVerifyInternal`, checks for 122863c9d2b1SAlex Zinenko /// annotations being present even if they can be inferred from the body. 122963c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure 1230135e5bf8SAlex Zinenko verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, 123163c9d2b1SAlex Zinenko bool alsoVerifyInternal = false) { 123263c9d2b1SAlex Zinenko auto transformOp = cast<transform::TransformOpInterface>(op.getOperation()); 123363c9d2b1SAlex Zinenko llvm::SmallDenseSet<unsigned> consumedArguments; 123463c9d2b1SAlex Zinenko if (!op.isExternal()) { 123563c9d2b1SAlex Zinenko transform::getConsumedBlockArguments(op.getFunctionBody().front(), 123663c9d2b1SAlex Zinenko consumedArguments); 123763c9d2b1SAlex Zinenko } 123863c9d2b1SAlex Zinenko for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { 123963c9d2b1SAlex Zinenko bool isConsumed = 124063c9d2b1SAlex Zinenko op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != 124163c9d2b1SAlex Zinenko nullptr; 124263c9d2b1SAlex Zinenko bool isReadOnly = 124363c9d2b1SAlex Zinenko op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != 124463c9d2b1SAlex Zinenko nullptr; 124563c9d2b1SAlex Zinenko if (isConsumed && isReadOnly) { 124663c9d2b1SAlex Zinenko return transformOp.emitSilenceableError() 124763c9d2b1SAlex Zinenko << "argument #" << i << " cannot be both readonly and consumed"; 124863c9d2b1SAlex Zinenko } 124963c9d2b1SAlex Zinenko if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { 125063c9d2b1SAlex Zinenko return transformOp.emitSilenceableError() 125163c9d2b1SAlex Zinenko << "must provide consumed/readonly status for arguments of " 125263c9d2b1SAlex Zinenko "external or called ops"; 125363c9d2b1SAlex Zinenko } 125463c9d2b1SAlex Zinenko if (op.isExternal()) 125563c9d2b1SAlex Zinenko continue; 125663c9d2b1SAlex Zinenko 125763c9d2b1SAlex Zinenko if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { 125863c9d2b1SAlex Zinenko return transformOp.emitSilenceableError() 125963c9d2b1SAlex Zinenko << "argument #" << i 126063c9d2b1SAlex Zinenko << " is consumed in the body but is not marked as such"; 126163c9d2b1SAlex Zinenko } 1262135e5bf8SAlex Zinenko if (emitWarnings && !consumedArguments.contains(i) && isConsumed) { 1263135e5bf8SAlex Zinenko // Cannot use op.emitWarning() here as it would attempt to verify the op 1264135e5bf8SAlex Zinenko // before printing, resulting in infinite recursion. 1265135e5bf8SAlex Zinenko emitWarning(op->getLoc()) 1266135e5bf8SAlex Zinenko << "op argument #" << i 126763c9d2b1SAlex Zinenko << " is not consumed in the body but is marked as consumed"; 126863c9d2b1SAlex Zinenko } 126963c9d2b1SAlex Zinenko } 127063c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::success(); 127163c9d2b1SAlex Zinenko } 127263c9d2b1SAlex Zinenko 127363c9d2b1SAlex Zinenko LogicalResult transform::ForeachMatchOp::verifySymbolUses( 127463c9d2b1SAlex Zinenko SymbolTableCollection &symbolTable) { 127563c9d2b1SAlex Zinenko assert(getMatchers().size() == getActions().size()); 127663c9d2b1SAlex Zinenko auto consumedAttr = 127763c9d2b1SAlex Zinenko StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); 127863c9d2b1SAlex Zinenko for (auto &&[matcher, action] : 127963c9d2b1SAlex Zinenko llvm::zip_equal(getMatchers(), getActions())) { 1280e4b04b39SOleksandr "Alex" Zinenko // Presence and typing. 128163c9d2b1SAlex Zinenko auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( 128263c9d2b1SAlex Zinenko symbolTable.lookupNearestSymbolFrom(getOperation(), 128363c9d2b1SAlex Zinenko cast<SymbolRefAttr>(matcher))); 128463c9d2b1SAlex Zinenko auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>( 128563c9d2b1SAlex Zinenko symbolTable.lookupNearestSymbolFrom(getOperation(), 128663c9d2b1SAlex Zinenko cast<SymbolRefAttr>(action))); 128763c9d2b1SAlex Zinenko if (!matcherSymbol || 128863c9d2b1SAlex Zinenko !isa<TransformOpInterface>(matcherSymbol.getOperation())) 128963c9d2b1SAlex Zinenko return emitError() << "unresolved matcher symbol " << matcher; 129063c9d2b1SAlex Zinenko if (!actionSymbol || 129163c9d2b1SAlex Zinenko !isa<TransformOpInterface>(actionSymbol.getOperation())) 129263c9d2b1SAlex Zinenko return emitError() << "unresolved action symbol " << action; 129363c9d2b1SAlex Zinenko 129463c9d2b1SAlex Zinenko if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, 1295135e5bf8SAlex Zinenko /*emitWarnings=*/false, 129663c9d2b1SAlex Zinenko /*alsoVerifyInternal=*/true) 129763c9d2b1SAlex Zinenko .checkAndReport())) { 129863c9d2b1SAlex Zinenko return failure(); 129963c9d2b1SAlex Zinenko } 130063c9d2b1SAlex Zinenko if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, 1301135e5bf8SAlex Zinenko /*emitWarnings=*/false, 130263c9d2b1SAlex Zinenko /*alsoVerifyInternal=*/true) 130363c9d2b1SAlex Zinenko .checkAndReport())) { 130463c9d2b1SAlex Zinenko return failure(); 130563c9d2b1SAlex Zinenko } 130663c9d2b1SAlex Zinenko 1307e4b04b39SOleksandr "Alex" Zinenko // Input -> matcher forwarding. 1308e4b04b39SOleksandr "Alex" Zinenko TypeRange operandTypes = getOperandTypes(); 1309e4b04b39SOleksandr "Alex" Zinenko TypeRange matcherArguments = matcherSymbol.getArgumentTypes(); 1310e4b04b39SOleksandr "Alex" Zinenko if (operandTypes.size() != matcherArguments.size()) { 1311e4b04b39SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 1312e4b04b39SOleksandr "Alex" Zinenko emitError() << "the number of operands (" << operandTypes.size() 1313e4b04b39SOleksandr "Alex" Zinenko << ") doesn't match the number of matcher arguments (" 1314e4b04b39SOleksandr "Alex" Zinenko << matcherArguments.size() << ") for " << matcher; 1315e4b04b39SOleksandr "Alex" Zinenko diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1316e4b04b39SOleksandr "Alex" Zinenko return diag; 1317e4b04b39SOleksandr "Alex" Zinenko } 1318e4b04b39SOleksandr "Alex" Zinenko for (auto &&[i, operand, argument] : 1319e4b04b39SOleksandr "Alex" Zinenko llvm::enumerate(operandTypes, matcherArguments)) { 1320e4b04b39SOleksandr "Alex" Zinenko if (matcherSymbol.getArgAttr(i, consumedAttr)) { 1321e4b04b39SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 1322e4b04b39SOleksandr "Alex" Zinenko emitOpError() 1323e4b04b39SOleksandr "Alex" Zinenko << "does not expect matcher symbol to consume its operand #" << i; 1324e4b04b39SOleksandr "Alex" Zinenko diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1325e4b04b39SOleksandr "Alex" Zinenko return diag; 1326e4b04b39SOleksandr "Alex" Zinenko } 1327e4b04b39SOleksandr "Alex" Zinenko 1328e4b04b39SOleksandr "Alex" Zinenko if (implementSameTransformInterface(operand, argument)) 1329e4b04b39SOleksandr "Alex" Zinenko continue; 1330e4b04b39SOleksandr "Alex" Zinenko 1331e4b04b39SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 1332e4b04b39SOleksandr "Alex" Zinenko emitError() 1333e4b04b39SOleksandr "Alex" Zinenko << "mismatching type interfaces for operand and matcher argument #" 1334e4b04b39SOleksandr "Alex" Zinenko << i << " of matcher " << matcher; 1335e4b04b39SOleksandr "Alex" Zinenko diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; 1336e4b04b39SOleksandr "Alex" Zinenko return diag; 1337e4b04b39SOleksandr "Alex" Zinenko } 1338e4b04b39SOleksandr "Alex" Zinenko 1339e4b04b39SOleksandr "Alex" Zinenko // Matcher -> action forwarding. 1340e4b04b39SOleksandr "Alex" Zinenko TypeRange matcherResults = matcherSymbol.getResultTypes(); 1341e4b04b39SOleksandr "Alex" Zinenko TypeRange actionArguments = actionSymbol.getArgumentTypes(); 134263c9d2b1SAlex Zinenko if (matcherResults.size() != actionArguments.size()) { 134363c9d2b1SAlex Zinenko return emitError() << "mismatching number of matcher results and " 134463c9d2b1SAlex Zinenko "action arguments between " 134563c9d2b1SAlex Zinenko << matcher << " (" << matcherResults.size() << ") and " 134663c9d2b1SAlex Zinenko << action << " (" << actionArguments.size() << ")"; 134763c9d2b1SAlex Zinenko } 134863c9d2b1SAlex Zinenko for (auto &&[i, matcherType, actionType] : 134963c9d2b1SAlex Zinenko llvm::enumerate(matcherResults, actionArguments)) { 135063c9d2b1SAlex Zinenko if (implementSameTransformInterface(matcherType, actionType)) 135163c9d2b1SAlex Zinenko continue; 135263c9d2b1SAlex Zinenko 135363c9d2b1SAlex Zinenko return emitError() << "mismatching type interfaces for matcher result " 135463c9d2b1SAlex Zinenko "and action argument #" 1355e4b04b39SOleksandr "Alex" Zinenko << i << "of matcher " << matcher << " and action " 1356e4b04b39SOleksandr "Alex" Zinenko << action; 135763c9d2b1SAlex Zinenko } 135863c9d2b1SAlex Zinenko 1359e4b04b39SOleksandr "Alex" Zinenko // Action -> result forwarding. 1360e4b04b39SOleksandr "Alex" Zinenko TypeRange actionResults = actionSymbol.getResultTypes(); 1361e4b04b39SOleksandr "Alex" Zinenko auto resultTypes = TypeRange(getResultTypes()).drop_front(); 1362e4b04b39SOleksandr "Alex" Zinenko if (actionResults.size() != resultTypes.size()) { 136363c9d2b1SAlex Zinenko InFlightDiagnostic diag = 1364e4b04b39SOleksandr "Alex" Zinenko emitError() << "the number of action results (" 1365e4b04b39SOleksandr "Alex" Zinenko << actionResults.size() << ") for " << action 1366e4b04b39SOleksandr "Alex" Zinenko << " doesn't match the number of extra op results (" 1367e4b04b39SOleksandr "Alex" Zinenko << resultTypes.size() << ")"; 136863c9d2b1SAlex Zinenko diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; 136963c9d2b1SAlex Zinenko return diag; 137063c9d2b1SAlex Zinenko } 1371e4b04b39SOleksandr "Alex" Zinenko for (auto &&[i, resultType, actionType] : 1372e4b04b39SOleksandr "Alex" Zinenko llvm::enumerate(resultTypes, actionResults)) { 1373e4b04b39SOleksandr "Alex" Zinenko if (implementSameTransformInterface(resultType, actionType)) 1374e4b04b39SOleksandr "Alex" Zinenko continue; 137563c9d2b1SAlex Zinenko 137663c9d2b1SAlex Zinenko InFlightDiagnostic diag = 1377e4b04b39SOleksandr "Alex" Zinenko emitError() << "mismatching type interfaces for action result #" << i 1378e4b04b39SOleksandr "Alex" Zinenko << " of action " << action << " and op result"; 1379e4b04b39SOleksandr "Alex" Zinenko diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; 138063c9d2b1SAlex Zinenko return diag; 138163c9d2b1SAlex Zinenko } 138263c9d2b1SAlex Zinenko } 138363c9d2b1SAlex Zinenko return success(); 138463c9d2b1SAlex Zinenko } 138563c9d2b1SAlex Zinenko 138663c9d2b1SAlex Zinenko //===----------------------------------------------------------------------===// 1387bba85ebdSAlex Zinenko // ForeachOp 1388bba85ebdSAlex Zinenko //===----------------------------------------------------------------------===// 1389bba85ebdSAlex Zinenko 1390bba85ebdSAlex Zinenko DiagnosedSilenceableFailure 1391c63d2b2cSMatthias Springer transform::ForeachOp::apply(transform::TransformRewriter &rewriter, 1392c63d2b2cSMatthias Springer transform::TransformResults &results, 1393bffec215SMatthias Springer transform::TransformState &state) { 1394d462bf68SRolf Morel // We store the payloads before executing the body as ops may be removed from 1395d462bf68SRolf Morel // the mapping by the TrackingRewriter while iteration is in progress. 1396d462bf68SRolf Morel SmallVector<SmallVector<MappedValue>> payloads; 1397d462bf68SRolf Morel detail::prepareValueMappings(payloads, getTargets(), state); 1398d462bf68SRolf Morel size_t numIterations = payloads.empty() ? 0 : payloads.front().size(); 1399d1ca1d01SGuillermo Callaghan bool withZipShortest = getWithZipShortest(); 1400a9efcbf4Smuneebkhan85 1401a9efcbf4Smuneebkhan85 // In case of `zip_shortest`, set the number of iterations to the 1402a9efcbf4Smuneebkhan85 // smallest payload in the targets. 1403d1ca1d01SGuillermo Callaghan if (withZipShortest) { 1404a9efcbf4Smuneebkhan85 numIterations = 1405a9efcbf4Smuneebkhan85 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A, 1406a9efcbf4Smuneebkhan85 const SmallVector<MappedValue> &B) { 1407a9efcbf4Smuneebkhan85 return A.size() < B.size(); 1408a9efcbf4Smuneebkhan85 })->size(); 1409a9efcbf4Smuneebkhan85 1410a9efcbf4Smuneebkhan85 for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++) 1411a9efcbf4Smuneebkhan85 payloads[argIdx].resize(numIterations); 1412a9efcbf4Smuneebkhan85 } 1413d462bf68SRolf Morel 1414d462bf68SRolf Morel // As we will be "zipping" over them, check all payloads have the same size. 1415a9efcbf4Smuneebkhan85 // `zip_shortest` adjusts all payloads to the same size, so skip this check 1416a9efcbf4Smuneebkhan85 // when true. 1417d1ca1d01SGuillermo Callaghan for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size(); 1418a9efcbf4Smuneebkhan85 argIdx++) { 1419d462bf68SRolf Morel if (payloads[argIdx].size() != numIterations) { 1420d462bf68SRolf Morel return emitSilenceableError() 1421d462bf68SRolf Morel << "prior targets' payload size (" << numIterations 1422d462bf68SRolf Morel << ") differs from payload size (" << payloads[argIdx].size() 1423d462bf68SRolf Morel << ") of target " << getTargets()[argIdx]; 1424d462bf68SRolf Morel } 1425d462bf68SRolf Morel } 1426d462bf68SRolf Morel 1427d462bf68SRolf Morel // Start iterating, indexing into payloads to obtain the right arguments to 1428d462bf68SRolf Morel // call the body with - each slice of payloads at the same argument index 1429d462bf68SRolf Morel // corresponding to a tuple to use as the body's block arguments. 1430d462bf68SRolf Morel ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments(); 1431d462bf68SRolf Morel SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {}); 1432d462bf68SRolf Morel for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) { 1433bffec215SMatthias Springer auto scope = state.make_region_scope(getBody()); 1434d462bf68SRolf Morel // Set up arguments to the region's block. 1435d462bf68SRolf Morel for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) { 1436d462bf68SRolf Morel MappedValue argument = payloads[argIdx][iterIdx]; 1437d462bf68SRolf Morel // Note that each blockArg's handle gets associated with just a single 1438d462bf68SRolf Morel // element from the corresponding target's payload. 1439d462bf68SRolf Morel if (failed(state.mapBlockArgument(blockArg, {argument}))) 1440bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1441d462bf68SRolf Morel } 1442bffec215SMatthias Springer 1443c1e6caacSMatthias Springer // Execute loop body. 1444bffec215SMatthias Springer for (Operation &transform : getBody().front().without_terminator()) { 1445bffec215SMatthias Springer DiagnosedSilenceableFailure result = state.applyTransform( 1446d462bf68SRolf Morel llvm::cast<transform::TransformOpInterface>(transform)); 1447bffec215SMatthias Springer if (!result.succeeded()) 1448bffec215SMatthias Springer return result; 1449bffec215SMatthias Springer } 1450c1e6caacSMatthias Springer 1451d462bf68SRolf Morel // Append yielded payloads to corresponding results from prior iterations. 1452d462bf68SRolf Morel OperandRange yieldOperands = getYieldOp().getOperands(); 1453d462bf68SRolf Morel for (auto &&[result, yieldOperand, resTuple] : 1454d462bf68SRolf Morel llvm::zip_equal(getResults(), yieldOperands, zippedResults)) 1455d462bf68SRolf Morel // NB: each iteration we add any number of ops/vals/params to a result. 1456d462bf68SRolf Morel if (isa<TransformHandleTypeInterface>(result.getType())) 1457d462bf68SRolf Morel llvm::append_range(resTuple, state.getPayloadOps(yieldOperand)); 1458d462bf68SRolf Morel else if (isa<TransformValueHandleTypeInterface>(result.getType())) 1459d462bf68SRolf Morel llvm::append_range(resTuple, state.getPayloadValues(yieldOperand)); 1460d462bf68SRolf Morel else if (isa<TransformParamTypeInterface>(result.getType())) 1461d462bf68SRolf Morel llvm::append_range(resTuple, state.getParams(yieldOperand)); 1462d462bf68SRolf Morel else 1463d462bf68SRolf Morel assert(false && "unhandled handle type"); 1464c1e6caacSMatthias Springer } 1465c1e6caacSMatthias Springer 1466d462bf68SRolf Morel // Associate the accumulated result payloads to the op's actual results. 1467d462bf68SRolf Morel for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults)) 1468d462bf68SRolf Morel results.setMappedValues(llvm::cast<OpResult>(result), resPayload); 1469c1e6caacSMatthias Springer 1470bffec215SMatthias Springer return DiagnosedSilenceableFailure::success(); 1471bffec215SMatthias Springer } 1472bffec215SMatthias Springer 1473bffec215SMatthias Springer void transform::ForeachOp::getEffects( 1474bffec215SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1475d462bf68SRolf Morel // NB: this `zip` should be `zip_equal` - while this op's verifier catches 1476d462bf68SRolf Morel // arity errors, this method might get called before/in absence of `verify()`. 1477d462bf68SRolf Morel for (auto &&[target, blockArg] : 14782c1ae801Sdonald chen llvm::zip(getTargetsMutable(), getBody().front().getArguments())) { 1479d462bf68SRolf Morel BlockArgument blockArgument = blockArg; 1480bffec215SMatthias Springer if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 1481d462bf68SRolf Morel return isHandleConsumed(blockArgument, 1482d462bf68SRolf Morel cast<TransformOpInterface>(&op)); 1483bffec215SMatthias Springer })) { 1484d462bf68SRolf Morel consumesHandle(target, effects); 1485bffec215SMatthias Springer } else { 1486d462bf68SRolf Morel onlyReadsHandle(target, effects); 1487d462bf68SRolf Morel } 1488bffec215SMatthias Springer } 1489c1e6caacSMatthias Springer 14904f63252dSMatthias Springer if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 14914f63252dSMatthias Springer return doesModifyPayload(cast<TransformOpInterface>(&op)); 14924f63252dSMatthias Springer })) { 14934f63252dSMatthias Springer modifiesPayload(effects); 14944f63252dSMatthias Springer } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 14954f63252dSMatthias Springer return doesReadPayload(cast<TransformOpInterface>(&op)); 14964f63252dSMatthias Springer })) { 14974f63252dSMatthias Springer onlyReadsPayload(effects); 14984f63252dSMatthias Springer } 14994f63252dSMatthias Springer 15002c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 1501bffec215SMatthias Springer } 1502bffec215SMatthias Springer 1503bffec215SMatthias Springer void transform::ForeachOp::getSuccessorRegions( 15044dd744acSMarkus Böck RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1505bffec215SMatthias Springer Region *bodyRegion = &getBody(); 15064dd744acSMarkus Böck if (point.isParent()) { 1507bffec215SMatthias Springer regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1508bffec215SMatthias Springer return; 1509bffec215SMatthias Springer } 1510bffec215SMatthias Springer 1511bffec215SMatthias Springer // Branch back to the region or the parent. 15124dd744acSMarkus Böck assert(point == getBody() && "unexpected region index"); 1513bffec215SMatthias Springer regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 1514bffec215SMatthias Springer regions.emplace_back(); 1515bffec215SMatthias Springer } 1516bffec215SMatthias Springer 1517bffec215SMatthias Springer OperandRange 15184dd744acSMarkus Böck transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { 1519d462bf68SRolf Morel // Each block argument handle is mapped to a subset (one op to be precise) 1520d462bf68SRolf Morel // of the payload of the corresponding `targets` operand of ForeachOp. 15214dd744acSMarkus Böck assert(point == getBody() && "unexpected region index"); 1522bffec215SMatthias Springer return getOperation()->getOperands(); 1523bffec215SMatthias Springer } 1524bffec215SMatthias Springer 1525c1e6caacSMatthias Springer transform::YieldOp transform::ForeachOp::getYieldOp() { 1526c1e6caacSMatthias Springer return cast<transform::YieldOp>(getBody().front().getTerminator()); 1527c1e6caacSMatthias Springer } 1528c1e6caacSMatthias Springer 1529c1e6caacSMatthias Springer LogicalResult transform::ForeachOp::verify() { 1530d462bf68SRolf Morel for (auto [targetOpt, bodyArgOpt] : 1531d462bf68SRolf Morel llvm::zip_longest(getTargets(), getBody().front().getArguments())) { 1532d462bf68SRolf Morel if (!targetOpt || !bodyArgOpt) 1533d462bf68SRolf Morel return emitOpError() << "expects the same number of targets as the body " 1534d462bf68SRolf Morel "has block arguments"; 1535d462bf68SRolf Morel if (targetOpt.value().getType() != bodyArgOpt.value().getType()) 1536d462bf68SRolf Morel return emitOpError( 1537d462bf68SRolf Morel "expects co-indexed targets and the body's " 1538d462bf68SRolf Morel "block arguments to have the same op/value/param type"); 1539d462bf68SRolf Morel } 1540d462bf68SRolf Morel 1541d462bf68SRolf Morel for (auto [resultOpt, yieldOperandOpt] : 1542d462bf68SRolf Morel llvm::zip_longest(getResults(), getYieldOp().getOperands())) { 1543d462bf68SRolf Morel if (!resultOpt || !yieldOperandOpt) 1544c1e6caacSMatthias Springer return emitOpError() << "expects the same number of results as the " 1545d462bf68SRolf Morel "yield terminator has operands"; 1546d462bf68SRolf Morel if (resultOpt.value().getType() != yieldOperandOpt.value().getType()) 1547d462bf68SRolf Morel return emitOpError("expects co-indexed results and yield " 1548d462bf68SRolf Morel "operands to have the same op/value/param type"); 1549d462bf68SRolf Morel } 1550d462bf68SRolf Morel 1551c1e6caacSMatthias Springer return success(); 1552c1e6caacSMatthias Springer } 1553c1e6caacSMatthias Springer 1554bffec215SMatthias Springer //===----------------------------------------------------------------------===// 15554106557aSMatthias Springer // GetParentOp 1556cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===// 1557cc6c1592SAlex Zinenko 15584106557aSMatthias Springer DiagnosedSilenceableFailure 15594106557aSMatthias Springer transform::GetParentOp::apply(transform::TransformRewriter &rewriter, 15604106557aSMatthias Springer transform::TransformResults &results, 15614106557aSMatthias Springer transform::TransformState &state) { 15624106557aSMatthias Springer SmallVector<Operation *> parents; 15634106557aSMatthias Springer DenseSet<Operation *> resultSet; 1564cc6c1592SAlex Zinenko for (Operation *target : state.getPayloadOps(getTarget())) { 156504736c7fSMatthias Springer Operation *parent = target; 156604736c7fSMatthias Springer for (int64_t i = 0, e = getNthParent(); i < e; ++i) { 156704736c7fSMatthias Springer parent = parent->getParentOp(); 156868033aaaSIngo Müller while (parent) { 15694106557aSMatthias Springer bool checkIsolatedFromAbove = 15704106557aSMatthias Springer !getIsolatedFromAbove() || 15714106557aSMatthias Springer parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); 15724106557aSMatthias Springer bool checkOpName = !getOpName().has_value() || 15734106557aSMatthias Springer parent->getName().getStringRef() == *getOpName(); 15744106557aSMatthias Springer if (checkIsolatedFromAbove && checkOpName) 15754106557aSMatthias Springer break; 157668033aaaSIngo Müller parent = parent->getParentOp(); 157768033aaaSIngo Müller } 1578cc6c1592SAlex Zinenko if (!parent) { 157998341df0SNicolas Vasilache if (getAllowEmptyResults()) { 158098341df0SNicolas Vasilache results.set(llvm::cast<OpResult>(getResult()), parents); 158198341df0SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 158298341df0SNicolas Vasilache } 15831d45282aSAlex Zinenko DiagnosedSilenceableFailure diag = 15841d45282aSAlex Zinenko emitSilenceableError() 15854106557aSMatthias Springer << "could not find a parent op that matches all requirements"; 1586cc6c1592SAlex Zinenko diag.attachNote(target->getLoc()) << "target op"; 1587cc6c1592SAlex Zinenko return diag; 1588cc6c1592SAlex Zinenko } 158904736c7fSMatthias Springer } 15904106557aSMatthias Springer if (getDeduplicate()) { 159167e7f05aSKazu Hirata if (resultSet.insert(parent).second) 15924106557aSMatthias Springer parents.push_back(parent); 15934106557aSMatthias Springer } else { 15944106557aSMatthias Springer parents.push_back(parent); 15954106557aSMatthias Springer } 15964106557aSMatthias Springer } 15974106557aSMatthias Springer results.set(llvm::cast<OpResult>(getResult()), parents); 15981d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 1599cc6c1592SAlex Zinenko } 1600cc6c1592SAlex Zinenko 1601cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===// 16024adf89fcSNicolas Vasilache // GetConsumersOfResult 16034adf89fcSNicolas Vasilache //===----------------------------------------------------------------------===// 16044adf89fcSNicolas Vasilache 16054adf89fcSNicolas Vasilache DiagnosedSilenceableFailure 1606c63d2b2cSMatthias Springer transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, 1607c63d2b2cSMatthias Springer transform::TransformResults &results, 16084adf89fcSNicolas Vasilache transform::TransformState &state) { 16094adf89fcSNicolas Vasilache int64_t resultNumber = getResultNumber(); 16100e37ef08SMatthias Springer auto payloadOps = state.getPayloadOps(getTarget()); 16110e37ef08SMatthias Springer if (std::empty(payloadOps)) { 16120e37ef08SMatthias Springer results.set(cast<OpResult>(getResult()), {}); 16134adf89fcSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 16144adf89fcSNicolas Vasilache } 16150e37ef08SMatthias Springer if (!llvm::hasSingleElement(payloadOps)) 16164adf89fcSNicolas Vasilache return emitDefiniteFailure() 16174adf89fcSNicolas Vasilache << "handle must be mapped to exactly one payload op"; 16184adf89fcSNicolas Vasilache 16190e37ef08SMatthias Springer Operation *target = *payloadOps.begin(); 16204adf89fcSNicolas Vasilache if (target->getNumResults() <= resultNumber) 16214adf89fcSNicolas Vasilache return emitDefiniteFailure() << "result number overflow"; 1622c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), 16234adf89fcSNicolas Vasilache llvm::to_vector(target->getResult(resultNumber).getUsers())); 16244adf89fcSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 16254adf89fcSNicolas Vasilache } 16264adf89fcSNicolas Vasilache 16274adf89fcSNicolas Vasilache //===----------------------------------------------------------------------===// 16284cf936d0SMatthias Springer // GetDefiningOp 16294cf936d0SMatthias Springer //===----------------------------------------------------------------------===// 16304cf936d0SMatthias Springer 16314cf936d0SMatthias Springer DiagnosedSilenceableFailure 1632c63d2b2cSMatthias Springer transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, 1633c63d2b2cSMatthias Springer transform::TransformResults &results, 16344cf936d0SMatthias Springer transform::TransformState &state) { 16354cf936d0SMatthias Springer SmallVector<Operation *> definingOps; 16364cf936d0SMatthias Springer for (Value v : state.getPayloadValues(getTarget())) { 1637c1fa60b4STres Popp if (llvm::isa<BlockArgument>(v)) { 16384cf936d0SMatthias Springer DiagnosedSilenceableFailure diag = 16394cf936d0SMatthias Springer emitSilenceableError() << "cannot get defining op of block argument"; 16404cf936d0SMatthias Springer diag.attachNote(v.getLoc()) << "target value"; 16414cf936d0SMatthias Springer return diag; 16424cf936d0SMatthias Springer } 16434cf936d0SMatthias Springer definingOps.push_back(v.getDefiningOp()); 16444cf936d0SMatthias Springer } 1645c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), definingOps); 16464cf936d0SMatthias Springer return DiagnosedSilenceableFailure::success(); 16474cf936d0SMatthias Springer } 16484cf936d0SMatthias Springer 16494cf936d0SMatthias Springer //===----------------------------------------------------------------------===// 1650ecd9dc04SNicolas Vasilache // GetProducerOfOperand 1651ecd9dc04SNicolas Vasilache //===----------------------------------------------------------------------===// 1652ecd9dc04SNicolas Vasilache 1653ecd9dc04SNicolas Vasilache DiagnosedSilenceableFailure 1654c63d2b2cSMatthias Springer transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, 1655c63d2b2cSMatthias Springer transform::TransformResults &results, 1656ecd9dc04SNicolas Vasilache transform::TransformState &state) { 1657ecd9dc04SNicolas Vasilache int64_t operandNumber = getOperandNumber(); 1658ecd9dc04SNicolas Vasilache SmallVector<Operation *> producers; 1659ecd9dc04SNicolas Vasilache for (Operation *target : state.getPayloadOps(getTarget())) { 1660ecd9dc04SNicolas Vasilache Operation *producer = 1661ecd9dc04SNicolas Vasilache target->getNumOperands() <= operandNumber 1662ecd9dc04SNicolas Vasilache ? nullptr 1663ecd9dc04SNicolas Vasilache : target->getOperand(operandNumber).getDefiningOp(); 1664ecd9dc04SNicolas Vasilache if (!producer) { 1665ecd9dc04SNicolas Vasilache DiagnosedSilenceableFailure diag = 1666ecd9dc04SNicolas Vasilache emitSilenceableError() 1667ecd9dc04SNicolas Vasilache << "could not find a producer for operand number: " << operandNumber 1668ecd9dc04SNicolas Vasilache << " of " << *target; 1669ecd9dc04SNicolas Vasilache diag.attachNote(target->getLoc()) << "target op"; 1670ecd9dc04SNicolas Vasilache return diag; 1671ecd9dc04SNicolas Vasilache } 1672ecd9dc04SNicolas Vasilache producers.push_back(producer); 1673ecd9dc04SNicolas Vasilache } 1674c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), producers); 1675ecd9dc04SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 1676ecd9dc04SNicolas Vasilache } 1677ecd9dc04SNicolas Vasilache 1678ecd9dc04SNicolas Vasilache //===----------------------------------------------------------------------===// 16795caab8bbSQuinn Dawkins // GetOperandOp 16805caab8bbSQuinn Dawkins //===----------------------------------------------------------------------===// 16815caab8bbSQuinn Dawkins 16825caab8bbSQuinn Dawkins DiagnosedSilenceableFailure 16835caab8bbSQuinn Dawkins transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, 16845caab8bbSQuinn Dawkins transform::TransformResults &results, 16855caab8bbSQuinn Dawkins transform::TransformState &state) { 16865caab8bbSQuinn Dawkins SmallVector<Value> operands; 16875caab8bbSQuinn Dawkins for (Operation *target : state.getPayloadOps(getTarget())) { 16885caab8bbSQuinn Dawkins SmallVector<int64_t> operandPositions; 16895caab8bbSQuinn Dawkins DiagnosedSilenceableFailure diag = expandTargetSpecification( 16905caab8bbSQuinn Dawkins getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 16915caab8bbSQuinn Dawkins target->getNumOperands(), operandPositions); 16925caab8bbSQuinn Dawkins if (diag.isSilenceableFailure()) { 16935caab8bbSQuinn Dawkins diag.attachNote(target->getLoc()) 16945caab8bbSQuinn Dawkins << "while considering positions of this payload operation"; 16955caab8bbSQuinn Dawkins return diag; 16965caab8bbSQuinn Dawkins } 16975caab8bbSQuinn Dawkins llvm::append_range(operands, 16985caab8bbSQuinn Dawkins llvm::map_range(operandPositions, [&](int64_t pos) { 16995caab8bbSQuinn Dawkins return target->getOperand(pos); 17005caab8bbSQuinn Dawkins })); 17015caab8bbSQuinn Dawkins } 17025caab8bbSQuinn Dawkins results.setValues(cast<OpResult>(getResult()), operands); 17035caab8bbSQuinn Dawkins return DiagnosedSilenceableFailure::success(); 17045caab8bbSQuinn Dawkins } 17055caab8bbSQuinn Dawkins 17065caab8bbSQuinn Dawkins LogicalResult transform::GetOperandOp::verify() { 17075caab8bbSQuinn Dawkins return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 17085caab8bbSQuinn Dawkins getIsInverted(), getIsAll()); 17095caab8bbSQuinn Dawkins } 17105caab8bbSQuinn Dawkins 17115caab8bbSQuinn Dawkins //===----------------------------------------------------------------------===// 17123ef062a4SMatthias Springer // GetResultOp 17133ef062a4SMatthias Springer //===----------------------------------------------------------------------===// 17143ef062a4SMatthias Springer 17153ef062a4SMatthias Springer DiagnosedSilenceableFailure 1716c63d2b2cSMatthias Springer transform::GetResultOp::apply(transform::TransformRewriter &rewriter, 1717c63d2b2cSMatthias Springer transform::TransformResults &results, 17183ef062a4SMatthias Springer transform::TransformState &state) { 17193ef062a4SMatthias Springer SmallVector<Value> opResults; 17203ef062a4SMatthias Springer for (Operation *target : state.getPayloadOps(getTarget())) { 17215caab8bbSQuinn Dawkins SmallVector<int64_t> resultPositions; 17225caab8bbSQuinn Dawkins DiagnosedSilenceableFailure diag = expandTargetSpecification( 17235caab8bbSQuinn Dawkins getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 17245caab8bbSQuinn Dawkins target->getNumResults(), resultPositions); 17255caab8bbSQuinn Dawkins if (diag.isSilenceableFailure()) { 17265caab8bbSQuinn Dawkins diag.attachNote(target->getLoc()) 17275caab8bbSQuinn Dawkins << "while considering positions of this payload operation"; 17283ef062a4SMatthias Springer return diag; 17293ef062a4SMatthias Springer } 17305caab8bbSQuinn Dawkins llvm::append_range(opResults, 17315caab8bbSQuinn Dawkins llvm::map_range(resultPositions, [&](int64_t pos) { 17325caab8bbSQuinn Dawkins return target->getResult(pos); 17335caab8bbSQuinn Dawkins })); 17343ef062a4SMatthias Springer } 17355caab8bbSQuinn Dawkins results.setValues(cast<OpResult>(getResult()), opResults); 17363ef062a4SMatthias Springer return DiagnosedSilenceableFailure::success(); 17373ef062a4SMatthias Springer } 17383ef062a4SMatthias Springer 17395caab8bbSQuinn Dawkins LogicalResult transform::GetResultOp::verify() { 17405caab8bbSQuinn Dawkins return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 17415caab8bbSQuinn Dawkins getIsInverted(), getIsAll()); 17425caab8bbSQuinn Dawkins } 17435caab8bbSQuinn Dawkins 17443ef062a4SMatthias Springer //===----------------------------------------------------------------------===// 1745dd81c6b8SAlex Zinenko // GetTypeOp 1746dd81c6b8SAlex Zinenko //===----------------------------------------------------------------------===// 1747dd81c6b8SAlex Zinenko 1748dd81c6b8SAlex Zinenko void transform::GetTypeOp::getEffects( 1749dd81c6b8SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 17502c1ae801Sdonald chen onlyReadsHandle(getValueMutable(), effects); 17512c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 1752dd81c6b8SAlex Zinenko onlyReadsPayload(effects); 1753dd81c6b8SAlex Zinenko } 1754dd81c6b8SAlex Zinenko 1755dd81c6b8SAlex Zinenko DiagnosedSilenceableFailure 1756dd81c6b8SAlex Zinenko transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, 1757dd81c6b8SAlex Zinenko transform::TransformResults &results, 1758dd81c6b8SAlex Zinenko transform::TransformState &state) { 1759dd81c6b8SAlex Zinenko SmallVector<Attribute> params; 1760085075a5SMatthias Springer for (Value value : state.getPayloadValues(getValue())) { 1761dd81c6b8SAlex Zinenko Type type = value.getType(); 1762dd81c6b8SAlex Zinenko if (getElemental()) { 1763dd81c6b8SAlex Zinenko if (auto shaped = dyn_cast<ShapedType>(type)) { 1764dd81c6b8SAlex Zinenko type = shaped.getElementType(); 1765dd81c6b8SAlex Zinenko } 1766dd81c6b8SAlex Zinenko } 1767dd81c6b8SAlex Zinenko params.push_back(TypeAttr::get(type)); 1768dd81c6b8SAlex Zinenko } 1769a5757c5bSChristian Sigg results.setParams(cast<OpResult>(getResult()), params); 1770dd81c6b8SAlex Zinenko return DiagnosedSilenceableFailure::success(); 1771dd81c6b8SAlex Zinenko } 1772dd81c6b8SAlex Zinenko 1773dd81c6b8SAlex Zinenko //===----------------------------------------------------------------------===// 1774fb409a28SAlex Zinenko // IncludeOp 1775fb409a28SAlex Zinenko //===----------------------------------------------------------------------===// 1776fb409a28SAlex Zinenko 1777fb409a28SAlex Zinenko /// Applies the transform ops contained in `block`. Maps `results` to the same 1778fb409a28SAlex Zinenko /// values as the operands of the block terminator. 1779fb409a28SAlex Zinenko static DiagnosedSilenceableFailure 1780fb409a28SAlex Zinenko applySequenceBlock(Block &block, transform::FailurePropagationMode mode, 1781fb409a28SAlex Zinenko transform::TransformState &state, 1782fb409a28SAlex Zinenko transform::TransformResults &results) { 1783fb409a28SAlex Zinenko // Apply the sequenced ops one by one. 1784fb409a28SAlex Zinenko for (Operation &transform : block.without_terminator()) { 1785fb409a28SAlex Zinenko DiagnosedSilenceableFailure result = 1786fb409a28SAlex Zinenko state.applyTransform(cast<transform::TransformOpInterface>(transform)); 1787fb409a28SAlex Zinenko if (result.isDefiniteFailure()) 1788fb409a28SAlex Zinenko return result; 1789fb409a28SAlex Zinenko 1790fb409a28SAlex Zinenko if (result.isSilenceableFailure()) { 1791fb409a28SAlex Zinenko if (mode == transform::FailurePropagationMode::Propagate) { 1792fb409a28SAlex Zinenko // Propagate empty results in case of early exit. 1793fb409a28SAlex Zinenko forwardEmptyOperands(&block, state, results); 1794fb409a28SAlex Zinenko return result; 1795fb409a28SAlex Zinenko } 1796fb409a28SAlex Zinenko (void)result.silence(); 1797fb409a28SAlex Zinenko } 1798fb409a28SAlex Zinenko } 1799fb409a28SAlex Zinenko 1800fb409a28SAlex Zinenko // Forward the operation mapping for values yielded from the sequence to the 1801fb409a28SAlex Zinenko // values produced by the sequence op. 180263c9d2b1SAlex Zinenko transform::detail::forwardTerminatorOperands(&block, state, results); 1803fb409a28SAlex Zinenko return DiagnosedSilenceableFailure::success(); 1804fb409a28SAlex Zinenko } 1805fb409a28SAlex Zinenko 1806fb409a28SAlex Zinenko DiagnosedSilenceableFailure 1807c63d2b2cSMatthias Springer transform::IncludeOp::apply(transform::TransformRewriter &rewriter, 1808c63d2b2cSMatthias Springer transform::TransformResults &results, 1809fb409a28SAlex Zinenko transform::TransformState &state) { 1810fb409a28SAlex Zinenko auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1811fb409a28SAlex Zinenko getOperation(), getTarget()); 1812fb409a28SAlex Zinenko assert(callee && "unverified reference to unknown symbol"); 1813fb409a28SAlex Zinenko 181492c69468SAlex Zinenko if (callee.isExternal()) 181592c69468SAlex Zinenko return emitDefiniteFailure() << "unresolved external named sequence"; 181692c69468SAlex Zinenko 1817fb409a28SAlex Zinenko // Map operands to block arguments. 1818fb409a28SAlex Zinenko SmallVector<SmallVector<MappedValue>> mappings; 1819fb409a28SAlex Zinenko detail::prepareValueMappings(mappings, getOperands(), state); 182022259281SMatthias Springer auto scope = state.make_region_scope(callee.getBody()); 1821fb409a28SAlex Zinenko for (auto &&[arg, map] : 1822fb409a28SAlex Zinenko llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { 1823fb409a28SAlex Zinenko if (failed(state.mapBlockArgument(arg, map))) 1824fb409a28SAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1825fb409a28SAlex Zinenko } 1826fb409a28SAlex Zinenko 1827fb409a28SAlex Zinenko DiagnosedSilenceableFailure result = applySequenceBlock( 1828fb409a28SAlex Zinenko callee.getBody().front(), getFailurePropagationMode(), state, results); 1829fb409a28SAlex Zinenko mappings.clear(); 1830fb409a28SAlex Zinenko detail::prepareValueMappings( 1831fb409a28SAlex Zinenko mappings, callee.getBody().front().getTerminator()->getOperands(), state); 1832fb409a28SAlex Zinenko for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) 1833fb409a28SAlex Zinenko results.setMappedValues(result, mapping); 1834fb409a28SAlex Zinenko return result; 1835fb409a28SAlex Zinenko } 1836fb409a28SAlex Zinenko 1837fb409a28SAlex Zinenko static DiagnosedSilenceableFailure 1838135e5bf8SAlex Zinenko verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings); 1839fb409a28SAlex Zinenko 1840fb409a28SAlex Zinenko void transform::IncludeOp::getEffects( 1841fb409a28SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 184241109341SAlex Zinenko // Always mark as modifying the payload. 184341109341SAlex Zinenko // TODO: a mechanism to annotate effects on payload. Even when all handles are 184441109341SAlex Zinenko // only read, the payload may still be modified, so we currently stay on the 184541109341SAlex Zinenko // conservative side and always indicate modification. This may prevent some 184641109341SAlex Zinenko // code reordering. 184741109341SAlex Zinenko modifiesPayload(effects); 184841109341SAlex Zinenko 184941109341SAlex Zinenko // Results are always produced. 18502c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 185141109341SAlex Zinenko 185241109341SAlex Zinenko // Adds default effects to operands and results. This will be added if 185341109341SAlex Zinenko // preconditions fail so the trait verifier doesn't complain about missing 185441109341SAlex Zinenko // effects and the real precondition failure is reported later on. 18552c1ae801Sdonald chen auto defaultEffects = [&] { 18562c1ae801Sdonald chen onlyReadsHandle(getOperation()->getOpOperands(), effects); 18572c1ae801Sdonald chen }; 185841109341SAlex Zinenko 1859fb409a28SAlex Zinenko // Bail if the callee is unknown. This may run as part of the verification 1860fb409a28SAlex Zinenko // process before we verified the validity of the callee or of this op. 1861fb409a28SAlex Zinenko auto target = 1862fb409a28SAlex Zinenko getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName()); 1863fb409a28SAlex Zinenko if (!target) 186441109341SAlex Zinenko return defaultEffects(); 1865fb409a28SAlex Zinenko auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( 1866fb409a28SAlex Zinenko getOperation(), getTarget()); 1867fb409a28SAlex Zinenko if (!callee) 186841109341SAlex Zinenko return defaultEffects(); 1869fb409a28SAlex Zinenko DiagnosedSilenceableFailure earlyVerifierResult = 1870135e5bf8SAlex Zinenko verifyNamedSequenceOp(callee, /*emitWarnings=*/false); 1871fb409a28SAlex Zinenko if (!earlyVerifierResult.succeeded()) { 1872fb409a28SAlex Zinenko (void)earlyVerifierResult.silence(); 187341109341SAlex Zinenko return defaultEffects(); 1874fb409a28SAlex Zinenko } 1875fb409a28SAlex Zinenko 187641109341SAlex Zinenko for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { 187741109341SAlex Zinenko if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) 18782c1ae801Sdonald chen consumesHandle(getOperation()->getOpOperand(i), effects); 187941109341SAlex Zinenko else 18802c1ae801Sdonald chen onlyReadsHandle(getOperation()->getOpOperand(i), effects); 188141109341SAlex Zinenko } 1882fb409a28SAlex Zinenko } 1883fb409a28SAlex Zinenko 1884fb409a28SAlex Zinenko LogicalResult 1885fb409a28SAlex Zinenko transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1886fb409a28SAlex Zinenko // Access through indirection and do additional checking because this may be 1887fb409a28SAlex Zinenko // running before the main op verifier. 1888fb409a28SAlex Zinenko auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target"); 1889fb409a28SAlex Zinenko if (!targetAttr) 1890fb409a28SAlex Zinenko return emitOpError() << "expects a 'target' symbol reference attribute"; 1891fb409a28SAlex Zinenko 1892fb409a28SAlex Zinenko auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>( 1893fb409a28SAlex Zinenko *this, targetAttr); 1894fb409a28SAlex Zinenko if (!target) 1895fb409a28SAlex Zinenko return emitOpError() << "does not reference a named transform sequence"; 1896fb409a28SAlex Zinenko 1897fb409a28SAlex Zinenko FunctionType fnType = target.getFunctionType(); 1898fb409a28SAlex Zinenko if (fnType.getNumInputs() != getNumOperands()) 1899fb409a28SAlex Zinenko return emitError("incorrect number of operands for callee"); 1900fb409a28SAlex Zinenko 1901fb409a28SAlex Zinenko for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { 1902fb409a28SAlex Zinenko if (getOperand(i).getType() != fnType.getInput(i)) { 1903fb409a28SAlex Zinenko return emitOpError("operand type mismatch: expected operand type ") 1904fb409a28SAlex Zinenko << fnType.getInput(i) << ", but provided " 1905fb409a28SAlex Zinenko << getOperand(i).getType() << " for operand number " << i; 1906fb409a28SAlex Zinenko } 1907fb409a28SAlex Zinenko } 1908fb409a28SAlex Zinenko 1909fb409a28SAlex Zinenko if (fnType.getNumResults() != getNumResults()) 1910fb409a28SAlex Zinenko return emitError("incorrect number of results for callee"); 1911fb409a28SAlex Zinenko 1912fb409a28SAlex Zinenko for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { 1913fb409a28SAlex Zinenko Type resultType = getResult(i).getType(); 1914fb409a28SAlex Zinenko Type funcType = fnType.getResult(i); 191563c9d2b1SAlex Zinenko if (!implementSameTransformInterface(resultType, funcType)) { 1916fb409a28SAlex Zinenko return emitOpError() << "type of result #" << i 1917fb409a28SAlex Zinenko << " must implement the same transform dialect " 1918fb409a28SAlex Zinenko "interface as the corresponding callee result"; 1919fb409a28SAlex Zinenko } 1920fb409a28SAlex Zinenko } 1921fb409a28SAlex Zinenko 192263c9d2b1SAlex Zinenko return verifyFunctionLikeConsumeAnnotations( 1923135e5bf8SAlex Zinenko cast<FunctionOpInterface>(*target), /*emitWarnings=*/false, 192441109341SAlex Zinenko /*alsoVerifyInternal=*/true) 192541109341SAlex Zinenko .checkAndReport(); 1926fb409a28SAlex Zinenko } 1927fb409a28SAlex Zinenko 1928fb409a28SAlex Zinenko //===----------------------------------------------------------------------===// 192998341df0SNicolas Vasilache // MatchOperationEmptyOp 193098341df0SNicolas Vasilache //===----------------------------------------------------------------------===// 193198341df0SNicolas Vasilache 193298341df0SNicolas Vasilache DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( 193398341df0SNicolas Vasilache ::std::optional<::mlir::Operation *> maybeCurrent, 193498341df0SNicolas Vasilache transform::TransformResults &results, transform::TransformState &state) { 193598341df0SNicolas Vasilache if (!maybeCurrent.has_value()) { 19368483d18bSNicolas Vasilache DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); 193798341df0SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 193898341df0SNicolas Vasilache } 19398483d18bSNicolas Vasilache DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); 194098341df0SNicolas Vasilache return emitSilenceableError() << "operation is not empty"; 194198341df0SNicolas Vasilache } 194298341df0SNicolas Vasilache 194398341df0SNicolas Vasilache //===----------------------------------------------------------------------===// 19443fe7127dSAlex Zinenko // MatchOperationNameOp 19453fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 19463fe7127dSAlex Zinenko 19473fe7127dSAlex Zinenko DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( 19483fe7127dSAlex Zinenko Operation *current, transform::TransformResults &results, 19493fe7127dSAlex Zinenko transform::TransformState &state) { 19503fe7127dSAlex Zinenko StringRef currentOpName = current->getName().getStringRef(); 19513fe7127dSAlex Zinenko for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) { 19523fe7127dSAlex Zinenko if (acceptedAttr.getValue() == currentOpName) 19533fe7127dSAlex Zinenko return DiagnosedSilenceableFailure::success(); 19543fe7127dSAlex Zinenko } 19553fe7127dSAlex Zinenko return emitSilenceableError() << "wrong operation name"; 19563fe7127dSAlex Zinenko } 19573fe7127dSAlex Zinenko 19583fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 19593fe7127dSAlex Zinenko // MatchParamCmpIOp 19603fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 19613fe7127dSAlex Zinenko 19623fe7127dSAlex Zinenko DiagnosedSilenceableFailure 1963c63d2b2cSMatthias Springer transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, 1964c63d2b2cSMatthias Springer transform::TransformResults &results, 19653fe7127dSAlex Zinenko transform::TransformState &state) { 196670ebc78eSMehdi Amini auto signedAPIntAsString = [&](const APInt &value) { 19673fe7127dSAlex Zinenko std::string str; 19683fe7127dSAlex Zinenko llvm::raw_string_ostream os(str); 19693fe7127dSAlex Zinenko value.print(os, /*isSigned=*/true); 1970884221edSJOE1994 return str; 19713fe7127dSAlex Zinenko }; 19723fe7127dSAlex Zinenko 19733fe7127dSAlex Zinenko ArrayRef<Attribute> params = state.getParams(getParam()); 19743fe7127dSAlex Zinenko ArrayRef<Attribute> references = state.getParams(getReference()); 19753fe7127dSAlex Zinenko 19763fe7127dSAlex Zinenko if (params.size() != references.size()) { 19773fe7127dSAlex Zinenko return emitSilenceableError() 19783fe7127dSAlex Zinenko << "parameters have different payload lengths (" << params.size() 19793fe7127dSAlex Zinenko << " vs " << references.size() << ")"; 19803fe7127dSAlex Zinenko } 19813fe7127dSAlex Zinenko 19823fe7127dSAlex Zinenko for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { 1983c1fa60b4STres Popp auto intAttr = llvm::dyn_cast<IntegerAttr>(param); 1984c1fa60b4STres Popp auto refAttr = llvm::dyn_cast<IntegerAttr>(reference); 19853fe7127dSAlex Zinenko if (!intAttr || !refAttr) { 19863fe7127dSAlex Zinenko return emitDefiniteFailure() 19873fe7127dSAlex Zinenko << "non-integer parameter value not expected"; 19883fe7127dSAlex Zinenko } 19893fe7127dSAlex Zinenko if (intAttr.getType() != refAttr.getType()) { 19903fe7127dSAlex Zinenko return emitDefiniteFailure() 19913fe7127dSAlex Zinenko << "mismatching integer attribute types in parameter #" << i; 19923fe7127dSAlex Zinenko } 19933fe7127dSAlex Zinenko APInt value = intAttr.getValue(); 19943fe7127dSAlex Zinenko APInt refValue = refAttr.getValue(); 19953fe7127dSAlex Zinenko 19963fe7127dSAlex Zinenko // TODO: this copy will not be necessary in C++20. 19973fe7127dSAlex Zinenko int64_t position = i; 19983fe7127dSAlex Zinenko auto reportError = [&](StringRef direction) { 19993fe7127dSAlex Zinenko DiagnosedSilenceableFailure diag = 20003fe7127dSAlex Zinenko emitSilenceableError() << "expected parameter to be " << direction 20013fe7127dSAlex Zinenko << " " << signedAPIntAsString(refValue) 20023fe7127dSAlex Zinenko << ", got " << signedAPIntAsString(value); 20033fe7127dSAlex Zinenko diag.attachNote(getParam().getLoc()) 20043fe7127dSAlex Zinenko << "value # " << position 20053fe7127dSAlex Zinenko << " associated with the parameter defined here"; 20063fe7127dSAlex Zinenko return diag; 20073fe7127dSAlex Zinenko }; 20083fe7127dSAlex Zinenko 20093fe7127dSAlex Zinenko switch (getPredicate()) { 20103fe7127dSAlex Zinenko case MatchCmpIPredicate::eq: 20113fe7127dSAlex Zinenko if (value.eq(refValue)) 20123fe7127dSAlex Zinenko break; 20133fe7127dSAlex Zinenko return reportError("equal to"); 20143fe7127dSAlex Zinenko case MatchCmpIPredicate::ne: 20153fe7127dSAlex Zinenko if (value.ne(refValue)) 20163fe7127dSAlex Zinenko break; 20173fe7127dSAlex Zinenko return reportError("not equal to"); 20183fe7127dSAlex Zinenko case MatchCmpIPredicate::lt: 20193fe7127dSAlex Zinenko if (value.slt(refValue)) 20203fe7127dSAlex Zinenko break; 20213fe7127dSAlex Zinenko return reportError("less than"); 20223fe7127dSAlex Zinenko case MatchCmpIPredicate::le: 20233fe7127dSAlex Zinenko if (value.sle(refValue)) 20243fe7127dSAlex Zinenko break; 20253fe7127dSAlex Zinenko return reportError("less than or equal to"); 20263fe7127dSAlex Zinenko case MatchCmpIPredicate::gt: 20273fe7127dSAlex Zinenko if (value.sgt(refValue)) 20283fe7127dSAlex Zinenko break; 20293fe7127dSAlex Zinenko return reportError("greater than"); 20303fe7127dSAlex Zinenko case MatchCmpIPredicate::ge: 20313fe7127dSAlex Zinenko if (value.sge(refValue)) 20323fe7127dSAlex Zinenko break; 20333fe7127dSAlex Zinenko return reportError("greater than or equal to"); 20343fe7127dSAlex Zinenko } 20353fe7127dSAlex Zinenko } 20363fe7127dSAlex Zinenko return DiagnosedSilenceableFailure::success(); 20373fe7127dSAlex Zinenko } 20383fe7127dSAlex Zinenko 20393fe7127dSAlex Zinenko void transform::MatchParamCmpIOp::getEffects( 20403fe7127dSAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 20412c1ae801Sdonald chen onlyReadsHandle(getParamMutable(), effects); 20422c1ae801Sdonald chen onlyReadsHandle(getReferenceMutable(), effects); 20433fe7127dSAlex Zinenko } 20443fe7127dSAlex Zinenko 20453fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 20463fe7127dSAlex Zinenko // ParamConstantOp 20473fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 20483fe7127dSAlex Zinenko 20493fe7127dSAlex Zinenko DiagnosedSilenceableFailure 2050c63d2b2cSMatthias Springer transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, 2051c63d2b2cSMatthias Springer transform::TransformResults &results, 20523fe7127dSAlex Zinenko transform::TransformState &state) { 20533fe7127dSAlex Zinenko results.setParams(cast<OpResult>(getParam()), {getValue()}); 20543fe7127dSAlex Zinenko return DiagnosedSilenceableFailure::success(); 20553fe7127dSAlex Zinenko } 20563fe7127dSAlex Zinenko 20573fe7127dSAlex Zinenko //===----------------------------------------------------------------------===// 20588e03bfc3SAlex Zinenko // MergeHandlesOp 20598e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===// 20608e03bfc3SAlex Zinenko 20618e03bfc3SAlex Zinenko DiagnosedSilenceableFailure 2062c63d2b2cSMatthias Springer transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, 2063c63d2b2cSMatthias Springer transform::TransformResults &results, 20648e03bfc3SAlex Zinenko transform::TransformState &state) { 206519380396SQuinn Dawkins ValueRange handles = getHandles(); 206619380396SQuinn Dawkins if (isa<TransformHandleTypeInterface>(handles.front().getType())) { 20678e03bfc3SAlex Zinenko SmallVector<Operation *> operations; 206819380396SQuinn Dawkins for (Value operand : handles) 20698e03bfc3SAlex Zinenko llvm::append_range(operations, state.getPayloadOps(operand)); 20708e03bfc3SAlex Zinenko if (!getDeduplicate()) { 2071c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), operations); 20728e03bfc3SAlex Zinenko return DiagnosedSilenceableFailure::success(); 20738e03bfc3SAlex Zinenko } 20748e03bfc3SAlex Zinenko 20758e03bfc3SAlex Zinenko SetVector<Operation *> uniqued(operations.begin(), operations.end()); 2076c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); 20778e03bfc3SAlex Zinenko return DiagnosedSilenceableFailure::success(); 20788e03bfc3SAlex Zinenko } 20798e03bfc3SAlex Zinenko 208019380396SQuinn Dawkins if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) { 208119380396SQuinn Dawkins SmallVector<Attribute> attrs; 208219380396SQuinn Dawkins for (Value attribute : handles) 208319380396SQuinn Dawkins llvm::append_range(attrs, state.getParams(attribute)); 208419380396SQuinn Dawkins if (!getDeduplicate()) { 208519380396SQuinn Dawkins results.setParams(cast<OpResult>(getResult()), attrs); 208619380396SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 208719380396SQuinn Dawkins } 208819380396SQuinn Dawkins 208919380396SQuinn Dawkins SetVector<Attribute> uniqued(attrs.begin(), attrs.end()); 209019380396SQuinn Dawkins results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); 209119380396SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 209219380396SQuinn Dawkins } 209319380396SQuinn Dawkins 209419380396SQuinn Dawkins assert( 209519380396SQuinn Dawkins llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) && 209619380396SQuinn Dawkins "expected value handle type"); 209719380396SQuinn Dawkins SmallVector<Value> payloadValues; 209819380396SQuinn Dawkins for (Value value : handles) 209919380396SQuinn Dawkins llvm::append_range(payloadValues, state.getPayloadValues(value)); 210019380396SQuinn Dawkins if (!getDeduplicate()) { 210119380396SQuinn Dawkins results.setValues(cast<OpResult>(getResult()), payloadValues); 210219380396SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 210319380396SQuinn Dawkins } 210419380396SQuinn Dawkins 210519380396SQuinn Dawkins SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end()); 210619380396SQuinn Dawkins results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); 210719380396SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 210819380396SQuinn Dawkins } 210919380396SQuinn Dawkins 21104299be1aSAlex Zinenko bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { 21114299be1aSAlex Zinenko // Handles may be the same if deduplicating is enabled. 21124299be1aSAlex Zinenko return getDeduplicate(); 21134299be1aSAlex Zinenko } 21144299be1aSAlex Zinenko 21158e03bfc3SAlex Zinenko void transform::MergeHandlesOp::getEffects( 21168e03bfc3SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 21172c1ae801Sdonald chen onlyReadsHandle(getHandlesMutable(), effects); 21182c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 21198e03bfc3SAlex Zinenko 21208e03bfc3SAlex Zinenko // There are no effects on the Payload IR as this is only a handle 21218e03bfc3SAlex Zinenko // manipulation. 21228e03bfc3SAlex Zinenko } 21238e03bfc3SAlex Zinenko 21247df76121SMarkus Böck OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { 21258e03bfc3SAlex Zinenko if (getDeduplicate() || getHandles().size() != 1) 21268e03bfc3SAlex Zinenko return {}; 21278e03bfc3SAlex Zinenko 21288e03bfc3SAlex Zinenko // If deduplication is not required and there is only one operand, it can be 21298e03bfc3SAlex Zinenko // used directly instead of merging. 21308e03bfc3SAlex Zinenko return getHandles().front(); 21318e03bfc3SAlex Zinenko } 21328e03bfc3SAlex Zinenko 21338e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===// 2134fb409a28SAlex Zinenko // NamedSequenceOp 2135fb409a28SAlex Zinenko //===----------------------------------------------------------------------===// 2136fb409a28SAlex Zinenko 2137fb409a28SAlex Zinenko DiagnosedSilenceableFailure 2138c63d2b2cSMatthias Springer transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, 2139c63d2b2cSMatthias Springer transform::TransformResults &results, 2140fb409a28SAlex Zinenko transform::TransformState &state) { 21411bf08709SNicolas Vasilache if (isExternal()) 21421bf08709SNicolas Vasilache return emitDefiniteFailure() << "unresolved external named sequence"; 21431bf08709SNicolas Vasilache 21441bf08709SNicolas Vasilache // Map the entry block argument to the list of operations. 21451bf08709SNicolas Vasilache // Note: this is the same implementation as PossibleTopLevelTransformOp but 21461bf08709SNicolas Vasilache // without attaching the interface / trait since that is tailored to a 21471bf08709SNicolas Vasilache // dangling top-level op that does not get "called". 21481bf08709SNicolas Vasilache auto scope = state.make_region_scope(getBody()); 21491bf08709SNicolas Vasilache if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( 21501bf08709SNicolas Vasilache state, this->getOperation(), getBody()))) 21511bf08709SNicolas Vasilache return DiagnosedSilenceableFailure::definiteFailure(); 21521bf08709SNicolas Vasilache 21531bf08709SNicolas Vasilache return applySequenceBlock(getBody().front(), 21541bf08709SNicolas Vasilache FailurePropagationMode::Propagate, state, results); 2155fb409a28SAlex Zinenko } 2156fb409a28SAlex Zinenko 2157fb409a28SAlex Zinenko void transform::NamedSequenceOp::getEffects( 2158fb409a28SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 2159fb409a28SAlex Zinenko 2160fb409a28SAlex Zinenko ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, 2161fb409a28SAlex Zinenko OperationState &result) { 2162fb409a28SAlex Zinenko return function_interface_impl::parseFunctionOp( 2163fb409a28SAlex Zinenko parser, result, /*allowVariadic=*/false, 2164fb409a28SAlex Zinenko getFunctionTypeAttrName(result.name), 2165fb409a28SAlex Zinenko [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results, 2166fb409a28SAlex Zinenko function_interface_impl::VariadicFlag, 2167fb409a28SAlex Zinenko std::string &) { return builder.getFunctionType(inputs, results); }, 2168fb409a28SAlex Zinenko getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 2169fb409a28SAlex Zinenko } 2170fb409a28SAlex Zinenko 2171fb409a28SAlex Zinenko void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { 2172fb409a28SAlex Zinenko function_interface_impl::printFunctionOp( 2173fb409a28SAlex Zinenko printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false, 2174fb409a28SAlex Zinenko getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), 2175fb409a28SAlex Zinenko getResAttrsAttrName()); 2176fb409a28SAlex Zinenko } 2177fb409a28SAlex Zinenko 217863c9d2b1SAlex Zinenko /// Verifies that a symbol function-like transform dialect operation has the 217963c9d2b1SAlex Zinenko /// signature and the terminator that have conforming types, i.e., types 218063c9d2b1SAlex Zinenko /// implementing the same transform dialect type interface. If `allowExternal` 218163c9d2b1SAlex Zinenko /// is set, allow external symbols (declarations) and don't check the terminator 218263c9d2b1SAlex Zinenko /// as it may not exist. 218363c9d2b1SAlex Zinenko static DiagnosedSilenceableFailure 218463c9d2b1SAlex Zinenko verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { 218563c9d2b1SAlex Zinenko if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 218663c9d2b1SAlex Zinenko DiagnosedSilenceableFailure diag = 218763c9d2b1SAlex Zinenko emitSilenceableFailure(op) 218863c9d2b1SAlex Zinenko << "cannot be defined inside another transform op"; 218963c9d2b1SAlex Zinenko diag.attachNote(parent.getLoc()) << "ancestor transform op"; 219063c9d2b1SAlex Zinenko return diag; 219163c9d2b1SAlex Zinenko } 219263c9d2b1SAlex Zinenko 219363c9d2b1SAlex Zinenko if (op.isExternal() || op.getFunctionBody().empty()) { 219463c9d2b1SAlex Zinenko if (allowExternal) 219563c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::success(); 219663c9d2b1SAlex Zinenko 219763c9d2b1SAlex Zinenko return emitSilenceableFailure(op) << "cannot be external"; 219863c9d2b1SAlex Zinenko } 219963c9d2b1SAlex Zinenko 220063c9d2b1SAlex Zinenko if (op.getFunctionBody().front().empty()) 220163c9d2b1SAlex Zinenko return emitSilenceableFailure(op) << "expected a non-empty body block"; 220263c9d2b1SAlex Zinenko 220363c9d2b1SAlex Zinenko Operation *terminator = &op.getFunctionBody().front().back(); 220463c9d2b1SAlex Zinenko if (!isa<transform::YieldOp>(terminator)) { 220563c9d2b1SAlex Zinenko DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 220663c9d2b1SAlex Zinenko << "expected '" 220763c9d2b1SAlex Zinenko << transform::YieldOp::getOperationName() 220863c9d2b1SAlex Zinenko << "' as terminator"; 220963c9d2b1SAlex Zinenko diag.attachNote(terminator->getLoc()) << "terminator"; 221063c9d2b1SAlex Zinenko return diag; 221163c9d2b1SAlex Zinenko } 221263c9d2b1SAlex Zinenko 221363c9d2b1SAlex Zinenko if (terminator->getNumOperands() != op.getResultTypes().size()) { 221463c9d2b1SAlex Zinenko return emitSilenceableFailure(terminator) 221563c9d2b1SAlex Zinenko << "expected terminator to have as many operands as the parent op " 221663c9d2b1SAlex Zinenko "has results"; 221763c9d2b1SAlex Zinenko } 221863c9d2b1SAlex Zinenko for (auto [i, operandType, resultType] : llvm::zip_equal( 221963c9d2b1SAlex Zinenko llvm::seq<unsigned>(0, terminator->getNumOperands()), 222063c9d2b1SAlex Zinenko terminator->getOperands().getType(), op.getResultTypes())) { 222163c9d2b1SAlex Zinenko if (operandType == resultType) 222263c9d2b1SAlex Zinenko continue; 222363c9d2b1SAlex Zinenko return emitSilenceableFailure(terminator) 222463c9d2b1SAlex Zinenko << "the type of the terminator operand #" << i 222563c9d2b1SAlex Zinenko << " must match the type of the corresponding parent op result (" 222663c9d2b1SAlex Zinenko << operandType << " vs " << resultType << ")"; 222763c9d2b1SAlex Zinenko } 222863c9d2b1SAlex Zinenko 222963c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::success(); 223063c9d2b1SAlex Zinenko } 223163c9d2b1SAlex Zinenko 2232fb409a28SAlex Zinenko /// Verification of a NamedSequenceOp. This does not report the error 2233fb409a28SAlex Zinenko /// immediately, so it can be used to check for op's well-formedness before the 2234fb409a28SAlex Zinenko /// verifier runs, e.g., during trait verification. 2235fb409a28SAlex Zinenko static DiagnosedSilenceableFailure 2236135e5bf8SAlex Zinenko verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { 2237fb409a28SAlex Zinenko if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) { 2238fb409a28SAlex Zinenko if (!parent->getAttr( 2239fb409a28SAlex Zinenko transform::TransformDialect::kWithNamedSequenceAttrName)) { 2240fb409a28SAlex Zinenko DiagnosedSilenceableFailure diag = 2241fb409a28SAlex Zinenko emitSilenceableFailure(op) 2242fb409a28SAlex Zinenko << "expects the parent symbol table to have the '" 2243fb409a28SAlex Zinenko << transform::TransformDialect::kWithNamedSequenceAttrName 2244fb409a28SAlex Zinenko << "' attribute"; 2245fb409a28SAlex Zinenko diag.attachNote(parent->getLoc()) << "symbol table operation"; 2246fb409a28SAlex Zinenko return diag; 2247fb409a28SAlex Zinenko } 2248fb409a28SAlex Zinenko } 2249fb409a28SAlex Zinenko 2250fb409a28SAlex Zinenko if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { 2251fb409a28SAlex Zinenko DiagnosedSilenceableFailure diag = 2252fb409a28SAlex Zinenko emitSilenceableFailure(op) 2253fb409a28SAlex Zinenko << "cannot be defined inside another transform op"; 2254fb409a28SAlex Zinenko diag.attachNote(parent.getLoc()) << "ancestor transform op"; 2255fb409a28SAlex Zinenko return diag; 2256fb409a28SAlex Zinenko } 2257fb409a28SAlex Zinenko 225892c69468SAlex Zinenko if (op.isExternal() || op.getBody().empty()) 2259135e5bf8SAlex Zinenko return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op), 2260135e5bf8SAlex Zinenko emitWarnings); 226192c69468SAlex Zinenko 2262fb409a28SAlex Zinenko if (op.getBody().front().empty()) 2263fb409a28SAlex Zinenko return emitSilenceableFailure(op) << "expected a non-empty body block"; 2264fb409a28SAlex Zinenko 2265fb409a28SAlex Zinenko Operation *terminator = &op.getBody().front().back(); 2266fb409a28SAlex Zinenko if (!isa<transform::YieldOp>(terminator)) { 2267fb409a28SAlex Zinenko DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) 2268fb409a28SAlex Zinenko << "expected '" 2269fb409a28SAlex Zinenko << transform::YieldOp::getOperationName() 2270fb409a28SAlex Zinenko << "' as terminator"; 2271fb409a28SAlex Zinenko diag.attachNote(terminator->getLoc()) << "terminator"; 2272fb409a28SAlex Zinenko return diag; 2273fb409a28SAlex Zinenko } 2274fb409a28SAlex Zinenko 2275fb409a28SAlex Zinenko if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { 2276fb409a28SAlex Zinenko return emitSilenceableFailure(terminator) 2277fb409a28SAlex Zinenko << "expected terminator to have as many operands as the parent op " 2278fb409a28SAlex Zinenko "has results"; 2279fb409a28SAlex Zinenko } 2280fb409a28SAlex Zinenko for (auto [i, operandType, resultType] : 2281fb409a28SAlex Zinenko llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()), 2282fb409a28SAlex Zinenko terminator->getOperands().getType(), 2283fb409a28SAlex Zinenko op.getFunctionType().getResults())) { 2284fb409a28SAlex Zinenko if (operandType == resultType) 2285fb409a28SAlex Zinenko continue; 2286fb409a28SAlex Zinenko return emitSilenceableFailure(terminator) 2287fb409a28SAlex Zinenko << "the type of the terminator operand #" << i 2288fb409a28SAlex Zinenko << " must match the type of the corresponding parent op result (" 2289fb409a28SAlex Zinenko << operandType << " vs " << resultType << ")"; 2290fb409a28SAlex Zinenko } 2291fb409a28SAlex Zinenko 229263c9d2b1SAlex Zinenko auto funcOp = cast<FunctionOpInterface>(*op); 229363c9d2b1SAlex Zinenko DiagnosedSilenceableFailure diag = 2294135e5bf8SAlex Zinenko verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); 229563c9d2b1SAlex Zinenko if (!diag.succeeded()) 229663c9d2b1SAlex Zinenko return diag; 229763c9d2b1SAlex Zinenko 229863c9d2b1SAlex Zinenko return verifyYieldingSingleBlockOp(funcOp, 229963c9d2b1SAlex Zinenko /*allowExternal=*/true); 2300fb409a28SAlex Zinenko } 2301fb409a28SAlex Zinenko 2302fb409a28SAlex Zinenko LogicalResult transform::NamedSequenceOp::verify() { 2303fb409a28SAlex Zinenko // Actual verification happens in a separate function for reusability. 2304135e5bf8SAlex Zinenko return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); 2305fb409a28SAlex Zinenko } 2306fb409a28SAlex Zinenko 23070c935894SNicolas Vasilache template <typename FnTy> 23080c935894SNicolas Vasilache static void buildSequenceBody(OpBuilder &builder, OperationState &state, 23090c935894SNicolas Vasilache Type bbArgType, TypeRange extraBindingTypes, 23100c935894SNicolas Vasilache FnTy bodyBuilder) { 23110c935894SNicolas Vasilache SmallVector<Type> types; 23120c935894SNicolas Vasilache types.reserve(1 + extraBindingTypes.size()); 23130c935894SNicolas Vasilache types.push_back(bbArgType); 23140c935894SNicolas Vasilache llvm::append_range(types, extraBindingTypes); 23150c935894SNicolas Vasilache 23160c935894SNicolas Vasilache OpBuilder::InsertionGuard guard(builder); 23170c935894SNicolas Vasilache Region *region = state.regions.back().get(); 23180c935894SNicolas Vasilache Block *bodyBlock = 23190c935894SNicolas Vasilache builder.createBlock(region, region->begin(), types, 23200c935894SNicolas Vasilache SmallVector<Location>(types.size(), state.location)); 23210c935894SNicolas Vasilache 23220c935894SNicolas Vasilache // Populate body. 23230c935894SNicolas Vasilache builder.setInsertionPointToStart(bodyBlock); 23240c935894SNicolas Vasilache if constexpr (llvm::function_traits<FnTy>::num_args == 3) { 23250c935894SNicolas Vasilache bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); 23260c935894SNicolas Vasilache } else { 23270c935894SNicolas Vasilache bodyBuilder(builder, state.location, bodyBlock->getArgument(0), 23280c935894SNicolas Vasilache bodyBlock->getArguments().drop_front()); 23290c935894SNicolas Vasilache } 23300c935894SNicolas Vasilache } 23310c935894SNicolas Vasilache 23320c935894SNicolas Vasilache void transform::NamedSequenceOp::build(OpBuilder &builder, 23330c935894SNicolas Vasilache OperationState &state, StringRef symName, 23340c935894SNicolas Vasilache Type rootType, TypeRange resultTypes, 23350c935894SNicolas Vasilache SequenceBodyBuilderFn bodyBuilder, 23360c935894SNicolas Vasilache ArrayRef<NamedAttribute> attrs, 23370c935894SNicolas Vasilache ArrayRef<DictionaryAttr> argAttrs) { 23380c935894SNicolas Vasilache state.addAttribute(SymbolTable::getSymbolAttrName(), 23390c935894SNicolas Vasilache builder.getStringAttr(symName)); 23400c935894SNicolas Vasilache state.addAttribute(getFunctionTypeAttrName(state.name), 23418483d18bSNicolas Vasilache TypeAttr::get(FunctionType::get(builder.getContext(), 23428483d18bSNicolas Vasilache rootType, resultTypes))); 23430c935894SNicolas Vasilache state.attributes.append(attrs.begin(), attrs.end()); 23440c935894SNicolas Vasilache state.addRegion(); 23450c935894SNicolas Vasilache 23460c935894SNicolas Vasilache buildSequenceBody(builder, state, rootType, 23470c935894SNicolas Vasilache /*extraBindingTypes=*/TypeRange(), bodyBuilder); 23480c935894SNicolas Vasilache } 23490c935894SNicolas Vasilache 2350fb409a28SAlex Zinenko //===----------------------------------------------------------------------===// 2351f90b6090SOleksandr "Alex" Zinenko // NumAssociationsOp 2352f90b6090SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 2353f90b6090SOleksandr "Alex" Zinenko 2354f90b6090SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure 2355f90b6090SOleksandr "Alex" Zinenko transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, 2356f90b6090SOleksandr "Alex" Zinenko transform::TransformResults &results, 2357f90b6090SOleksandr "Alex" Zinenko transform::TransformState &state) { 2358f90b6090SOleksandr "Alex" Zinenko size_t numAssociations = 2359f90b6090SOleksandr "Alex" Zinenko llvm::TypeSwitch<Type, size_t>(getHandle().getType()) 2360f90b6090SOleksandr "Alex" Zinenko .Case([&](TransformHandleTypeInterface opHandle) { 2361f90b6090SOleksandr "Alex" Zinenko return llvm::range_size(state.getPayloadOps(getHandle())); 2362f90b6090SOleksandr "Alex" Zinenko }) 2363f90b6090SOleksandr "Alex" Zinenko .Case([&](TransformValueHandleTypeInterface valueHandle) { 2364f90b6090SOleksandr "Alex" Zinenko return llvm::range_size(state.getPayloadValues(getHandle())); 2365f90b6090SOleksandr "Alex" Zinenko }) 2366f90b6090SOleksandr "Alex" Zinenko .Case([&](TransformParamTypeInterface param) { 2367f90b6090SOleksandr "Alex" Zinenko return llvm::range_size(state.getParams(getHandle())); 2368f90b6090SOleksandr "Alex" Zinenko }) 2369f90b6090SOleksandr "Alex" Zinenko .Default([](Type) { 2370f90b6090SOleksandr "Alex" Zinenko llvm_unreachable("unknown kind of transform dialect type"); 2371f90b6090SOleksandr "Alex" Zinenko return 0; 2372f90b6090SOleksandr "Alex" Zinenko }); 2373a5757c5bSChristian Sigg results.setParams(cast<OpResult>(getNum()), 2374f90b6090SOleksandr "Alex" Zinenko rewriter.getI64IntegerAttr(numAssociations)); 2375f90b6090SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 2376f90b6090SOleksandr "Alex" Zinenko } 2377f90b6090SOleksandr "Alex" Zinenko 2378f90b6090SOleksandr "Alex" Zinenko LogicalResult transform::NumAssociationsOp::verify() { 2379f90b6090SOleksandr "Alex" Zinenko // Verify that the result type accepts an i64 attribute as payload. 2380a5757c5bSChristian Sigg auto resultType = cast<TransformParamTypeInterface>(getNum().getType()); 2381f90b6090SOleksandr "Alex" Zinenko return resultType 2382f90b6090SOleksandr "Alex" Zinenko .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) 2383f90b6090SOleksandr "Alex" Zinenko .checkAndReport(); 2384f90b6090SOleksandr "Alex" Zinenko } 2385f90b6090SOleksandr "Alex" Zinenko 2386f90b6090SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 2387894fdbc7SMatthias Springer // SelectOp 2388894fdbc7SMatthias Springer //===----------------------------------------------------------------------===// 2389894fdbc7SMatthias Springer 2390894fdbc7SMatthias Springer DiagnosedSilenceableFailure 2391894fdbc7SMatthias Springer transform::SelectOp::apply(transform::TransformRewriter &rewriter, 2392894fdbc7SMatthias Springer transform::TransformResults &results, 2393894fdbc7SMatthias Springer transform::TransformState &state) { 2394894fdbc7SMatthias Springer SmallVector<Operation *> result; 2395894fdbc7SMatthias Springer auto payloadOps = state.getPayloadOps(getTarget()); 2396894fdbc7SMatthias Springer for (Operation *op : payloadOps) { 2397894fdbc7SMatthias Springer if (op->getName().getStringRef() == getOpName()) 2398894fdbc7SMatthias Springer result.push_back(op); 2399894fdbc7SMatthias Springer } 2400894fdbc7SMatthias Springer results.set(cast<OpResult>(getResult()), result); 2401894fdbc7SMatthias Springer return DiagnosedSilenceableFailure::success(); 2402894fdbc7SMatthias Springer } 2403894fdbc7SMatthias Springer 2404894fdbc7SMatthias Springer //===----------------------------------------------------------------------===// 2405288529e7SMatthias Springer // SplitHandleOp 2406af664e44SNicolas Vasilache //===----------------------------------------------------------------------===// 2407af664e44SNicolas Vasilache 2408288529e7SMatthias Springer void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, 2409288529e7SMatthias Springer Value target, int64_t numResultHandles) { 2410c8fab80dSNicolas Vasilache result.addOperands(target); 241194d608d4SAlex Zinenko result.addTypes(SmallVector<Type>(numResultHandles, target.getType())); 2412c8fab80dSNicolas Vasilache } 2413c8fab80dSNicolas Vasilache 2414af664e44SNicolas Vasilache DiagnosedSilenceableFailure 2415c63d2b2cSMatthias Springer transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, 2416c63d2b2cSMatthias Springer transform::TransformResults &results, 2417af664e44SNicolas Vasilache transform::TransformState &state) { 24181c352e66SOleksandr "Alex" Zinenko int64_t numPayloads = 24191c352e66SOleksandr "Alex" Zinenko llvm::TypeSwitch<Type, int64_t>(getHandle().getType()) 24201c352e66SOleksandr "Alex" Zinenko .Case<TransformHandleTypeInterface>([&](auto x) { 24211c352e66SOleksandr "Alex" Zinenko return llvm::range_size(state.getPayloadOps(getHandle())); 24221c352e66SOleksandr "Alex" Zinenko }) 24231c352e66SOleksandr "Alex" Zinenko .Case<TransformValueHandleTypeInterface>([&](auto x) { 24241c352e66SOleksandr "Alex" Zinenko return llvm::range_size(state.getPayloadValues(getHandle())); 24251c352e66SOleksandr "Alex" Zinenko }) 24261c352e66SOleksandr "Alex" Zinenko .Case<TransformParamTypeInterface>([&](auto x) { 24271c352e66SOleksandr "Alex" Zinenko return llvm::range_size(state.getParams(getHandle())); 24281c352e66SOleksandr "Alex" Zinenko }) 24291c352e66SOleksandr "Alex" Zinenko .Default([](auto x) { 24301c352e66SOleksandr "Alex" Zinenko llvm_unreachable("unknown transform dialect type interface"); 24311c352e66SOleksandr "Alex" Zinenko return -1; 24321c352e66SOleksandr "Alex" Zinenko }); 24331c352e66SOleksandr "Alex" Zinenko 2434709098fbSMatthias Springer auto produceNumOpsError = [&]() { 2435af664e44SNicolas Vasilache return emitSilenceableError() 2436709098fbSMatthias Springer << getHandle() << " expected to contain " << this->getNumResults() 24371c352e66SOleksandr "Alex" Zinenko << " payloads but it contains " << numPayloads << " payloads"; 2438709098fbSMatthias Springer }; 2439288529e7SMatthias Springer 2440709098fbSMatthias Springer // Fail if there are more payload ops than results and no overflow result was 2441709098fbSMatthias Springer // specified. 24421c352e66SOleksandr "Alex" Zinenko if (numPayloads > getNumResults() && !getOverflowResult().has_value()) 2443709098fbSMatthias Springer return produceNumOpsError(); 2444709098fbSMatthias Springer 2445709098fbSMatthias Springer // Fail if there are more results than payload ops. Unless: 2446709098fbSMatthias Springer // - "fail_on_payload_too_small" is set to "false", or 2447709098fbSMatthias Springer // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. 24481c352e66SOleksandr "Alex" Zinenko if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() && 24491c352e66SOleksandr "Alex" Zinenko (numPayloads != 0 || !getPassThroughEmptyHandle())) 2450709098fbSMatthias Springer return produceNumOpsError(); 2451709098fbSMatthias Springer 24521c352e66SOleksandr "Alex" Zinenko // Distribute payloads. 24531c352e66SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {}); 2454709098fbSMatthias Springer if (getOverflowResult()) 24551c352e66SOleksandr "Alex" Zinenko resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults()); 24561c352e66SOleksandr "Alex" Zinenko 24571c352e66SOleksandr "Alex" Zinenko auto container = [&]() { 24581c352e66SOleksandr "Alex" Zinenko if (isa<TransformHandleTypeInterface>(getHandle().getType())) { 24591c352e66SOleksandr "Alex" Zinenko return llvm::map_to_vector( 24601c352e66SOleksandr "Alex" Zinenko state.getPayloadOps(getHandle()), 24611c352e66SOleksandr "Alex" Zinenko [](Operation *op) -> MappedValue { return op; }); 24621c352e66SOleksandr "Alex" Zinenko } 24631c352e66SOleksandr "Alex" Zinenko if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) { 24641c352e66SOleksandr "Alex" Zinenko return llvm::map_to_vector(state.getPayloadValues(getHandle()), 24651c352e66SOleksandr "Alex" Zinenko [](Value v) -> MappedValue { return v; }); 24661c352e66SOleksandr "Alex" Zinenko } 24671c352e66SOleksandr "Alex" Zinenko assert(isa<TransformParamTypeInterface>(getHandle().getType()) && 24681c352e66SOleksandr "Alex" Zinenko "unsupported kind of transform dialect type"); 24691c352e66SOleksandr "Alex" Zinenko return llvm::map_to_vector(state.getParams(getHandle()), 24701c352e66SOleksandr "Alex" Zinenko [](Attribute a) -> MappedValue { return a; }); 24711c352e66SOleksandr "Alex" Zinenko }(); 24721c352e66SOleksandr "Alex" Zinenko 24731c352e66SOleksandr "Alex" Zinenko for (auto &&en : llvm::enumerate(container)) { 2474709098fbSMatthias Springer int64_t resultNum = en.index(); 2475709098fbSMatthias Springer if (resultNum >= getNumResults()) 2476709098fbSMatthias Springer resultNum = *getOverflowResult(); 2477709098fbSMatthias Springer resultHandles[resultNum].push_back(en.value()); 2478709098fbSMatthias Springer } 2479709098fbSMatthias Springer 2480709098fbSMatthias Springer // Set transform op results. 2481709098fbSMatthias Springer for (auto &&it : llvm::enumerate(resultHandles)) 24821c352e66SOleksandr "Alex" Zinenko results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())), 24831c352e66SOleksandr "Alex" Zinenko it.value()); 2484288529e7SMatthias Springer 2485af664e44SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 2486af664e44SNicolas Vasilache } 2487af664e44SNicolas Vasilache 2488288529e7SMatthias Springer void transform::SplitHandleOp::getEffects( 2489af664e44SNicolas Vasilache SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 24902c1ae801Sdonald chen onlyReadsHandle(getHandleMutable(), effects); 24912c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 2492af664e44SNicolas Vasilache // There are no effects on the Payload IR as this is only a handle 2493af664e44SNicolas Vasilache // manipulation. 2494af664e44SNicolas Vasilache } 2495af664e44SNicolas Vasilache 2496709098fbSMatthias Springer LogicalResult transform::SplitHandleOp::verify() { 2497709098fbSMatthias Springer if (getOverflowResult().has_value() && 249880815dfbSlong.chen !(*getOverflowResult() < getNumResults())) 2499709098fbSMatthias Springer return emitOpError("overflow_result is not a valid result index"); 25001c352e66SOleksandr "Alex" Zinenko 25011c352e66SOleksandr "Alex" Zinenko for (Type resultType : getResultTypes()) { 25021c352e66SOleksandr "Alex" Zinenko if (implementSameTransformInterface(getHandle().getType(), resultType)) 25031c352e66SOleksandr "Alex" Zinenko continue; 25041c352e66SOleksandr "Alex" Zinenko 25051c352e66SOleksandr "Alex" Zinenko return emitOpError("expects result types to implement the same transform " 25061c352e66SOleksandr "Alex" Zinenko "interface as the operand type"); 25071c352e66SOleksandr "Alex" Zinenko } 25081c352e66SOleksandr "Alex" Zinenko 2509709098fbSMatthias Springer return success(); 2510709098fbSMatthias Springer } 2511709098fbSMatthias Springer 2512af664e44SNicolas Vasilache //===----------------------------------------------------------------------===// 251300d1a1a2SAlex Zinenko // ReplicateOp 251400d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===// 251500d1a1a2SAlex Zinenko 251600d1a1a2SAlex Zinenko DiagnosedSilenceableFailure 2517c63d2b2cSMatthias Springer transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, 2518c63d2b2cSMatthias Springer transform::TransformResults &results, 251900d1a1a2SAlex Zinenko transform::TransformState &state) { 25200e37ef08SMatthias Springer unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); 252100d1a1a2SAlex Zinenko for (const auto &en : llvm::enumerate(getHandles())) { 252200d1a1a2SAlex Zinenko Value handle = en.value(); 25230e37ef08SMatthias Springer if (isa<TransformHandleTypeInterface>(handle.getType())) { 25240e37ef08SMatthias Springer SmallVector<Operation *> current = 25250e37ef08SMatthias Springer llvm::to_vector(state.getPayloadOps(handle)); 252600d1a1a2SAlex Zinenko SmallVector<Operation *> payload; 252700d1a1a2SAlex Zinenko payload.reserve(numRepetitions * current.size()); 252800d1a1a2SAlex Zinenko for (unsigned i = 0; i < numRepetitions; ++i) 252900d1a1a2SAlex Zinenko llvm::append_range(payload, current); 2530c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload); 253188c5027bSAlex Zinenko } else { 2532c1fa60b4STres Popp assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) && 253388c5027bSAlex Zinenko "expected param type"); 253488c5027bSAlex Zinenko ArrayRef<Attribute> current = state.getParams(handle); 253588c5027bSAlex Zinenko SmallVector<Attribute> params; 253688c5027bSAlex Zinenko params.reserve(numRepetitions * current.size()); 253788c5027bSAlex Zinenko for (unsigned i = 0; i < numRepetitions; ++i) 253888c5027bSAlex Zinenko llvm::append_range(params, current); 2539c1fa60b4STres Popp results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]), 2540c1fa60b4STres Popp params); 254188c5027bSAlex Zinenko } 254200d1a1a2SAlex Zinenko } 254300d1a1a2SAlex Zinenko return DiagnosedSilenceableFailure::success(); 254400d1a1a2SAlex Zinenko } 254500d1a1a2SAlex Zinenko 254600d1a1a2SAlex Zinenko void transform::ReplicateOp::getEffects( 254700d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 25482c1ae801Sdonald chen onlyReadsHandle(getPatternMutable(), effects); 25492c1ae801Sdonald chen onlyReadsHandle(getHandlesMutable(), effects); 25502c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 255100d1a1a2SAlex Zinenko } 255200d1a1a2SAlex Zinenko 255300d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===// 255430f22429SAlex Zinenko // SequenceOp 255530f22429SAlex Zinenko //===----------------------------------------------------------------------===// 255630f22429SAlex Zinenko 25571d45282aSAlex Zinenko DiagnosedSilenceableFailure 2558c63d2b2cSMatthias Springer transform::SequenceOp::apply(transform::TransformRewriter &rewriter, 2559c63d2b2cSMatthias Springer transform::TransformResults &results, 25600eb403adSAlex Zinenko transform::TransformState &state) { 25610eb403adSAlex Zinenko // Map the entry block argument to the list of operations. 25620eb403adSAlex Zinenko auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 2563bba85ebdSAlex Zinenko if (failed(mapBlockArguments(state))) 2564bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 25650eb403adSAlex Zinenko 2566fb409a28SAlex Zinenko return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, 2567fb409a28SAlex Zinenko results); 25680eb403adSAlex Zinenko } 25690eb403adSAlex Zinenko 2570b9e40cdeSAlex Zinenko static ParseResult parseSequenceOpOperands( 257105423905SKazu Hirata OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, 2572b9e40cdeSAlex Zinenko Type &rootType, 2573b9e40cdeSAlex Zinenko SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, 2574b9e40cdeSAlex Zinenko SmallVectorImpl<Type> &extraBindingTypes) { 2575b9e40cdeSAlex Zinenko OpAsmParser::UnresolvedOperand rootOperand; 2576b9e40cdeSAlex Zinenko OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand); 2577b9e40cdeSAlex Zinenko if (!hasRoot.has_value()) { 2578b9e40cdeSAlex Zinenko root = std::nullopt; 2579b9e40cdeSAlex Zinenko return success(); 2580b9e40cdeSAlex Zinenko } 2581b9e40cdeSAlex Zinenko if (failed(hasRoot.value())) 2582b9e40cdeSAlex Zinenko return failure(); 2583b9e40cdeSAlex Zinenko root = rootOperand; 2584b9e40cdeSAlex Zinenko 2585b9e40cdeSAlex Zinenko if (succeeded(parser.parseOptionalComma())) { 2586b9e40cdeSAlex Zinenko if (failed(parser.parseOperandList(extraBindings))) 2587b9e40cdeSAlex Zinenko return failure(); 2588b9e40cdeSAlex Zinenko } 2589b9e40cdeSAlex Zinenko if (failed(parser.parseColon())) 2590b9e40cdeSAlex Zinenko return failure(); 2591b9e40cdeSAlex Zinenko 2592b9e40cdeSAlex Zinenko // The paren is truly optional. 2593b9e40cdeSAlex Zinenko (void)parser.parseOptionalLParen(); 2594b9e40cdeSAlex Zinenko 2595b9e40cdeSAlex Zinenko if (failed(parser.parseType(rootType))) { 2596b9e40cdeSAlex Zinenko return failure(); 2597b9e40cdeSAlex Zinenko } 2598b9e40cdeSAlex Zinenko 2599b9e40cdeSAlex Zinenko if (!extraBindings.empty()) { 2600b9e40cdeSAlex Zinenko if (parser.parseComma() || parser.parseTypeList(extraBindingTypes)) 2601b9e40cdeSAlex Zinenko return failure(); 2602b9e40cdeSAlex Zinenko } 2603b9e40cdeSAlex Zinenko 2604b9e40cdeSAlex Zinenko if (extraBindingTypes.size() != extraBindings.size()) { 2605b9e40cdeSAlex Zinenko return parser.emitError(parser.getNameLoc(), 2606b9e40cdeSAlex Zinenko "expected types to be provided for all operands"); 2607b9e40cdeSAlex Zinenko } 2608b9e40cdeSAlex Zinenko 2609b9e40cdeSAlex Zinenko // The paren is truly optional. 2610b9e40cdeSAlex Zinenko (void)parser.parseOptionalRParen(); 2611b9e40cdeSAlex Zinenko return success(); 2612b9e40cdeSAlex Zinenko } 2613b9e40cdeSAlex Zinenko 2614b9e40cdeSAlex Zinenko static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, 2615b9e40cdeSAlex Zinenko Value root, Type rootType, 2616b9e40cdeSAlex Zinenko ValueRange extraBindings, 2617b9e40cdeSAlex Zinenko TypeRange extraBindingTypes) { 2618b9e40cdeSAlex Zinenko if (!root) 2619b9e40cdeSAlex Zinenko return; 2620b9e40cdeSAlex Zinenko 2621b9e40cdeSAlex Zinenko printer << root; 2622b9e40cdeSAlex Zinenko bool hasExtras = !extraBindings.empty(); 2623b9e40cdeSAlex Zinenko if (hasExtras) { 2624b9e40cdeSAlex Zinenko printer << ", "; 2625b9e40cdeSAlex Zinenko printer.printOperands(extraBindings); 2626b9e40cdeSAlex Zinenko } 2627b9e40cdeSAlex Zinenko 2628b9e40cdeSAlex Zinenko printer << " : "; 2629b9e40cdeSAlex Zinenko if (hasExtras) 2630b9e40cdeSAlex Zinenko printer << "("; 2631b9e40cdeSAlex Zinenko 2632b9e40cdeSAlex Zinenko printer << rootType; 2633b9e40cdeSAlex Zinenko if (hasExtras) { 2634b9e40cdeSAlex Zinenko printer << ", "; 2635b9e40cdeSAlex Zinenko llvm::interleaveComma(extraBindingTypes, printer.getStream()); 2636b9e40cdeSAlex Zinenko printer << ")"; 2637b9e40cdeSAlex Zinenko } 2638b9e40cdeSAlex Zinenko } 2639b9e40cdeSAlex Zinenko 264040a8bd63SAlex Zinenko /// Returns `true` if the given op operand may be consuming the handle value in 264140a8bd63SAlex Zinenko /// the Transform IR. That is, if it may have a Free effect on it. 264240a8bd63SAlex Zinenko static bool isValueUsePotentialConsumer(OpOperand &use) { 264340a8bd63SAlex Zinenko // Conservatively assume the effect being present in absence of the interface. 2644e15b855eSAlex Zinenko auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); 2645e15b855eSAlex Zinenko if (!iface) 264640a8bd63SAlex Zinenko return true; 264740a8bd63SAlex Zinenko 2648e15b855eSAlex Zinenko return isHandleConsumed(use.get(), iface); 264940a8bd63SAlex Zinenko } 265040a8bd63SAlex Zinenko 265140a8bd63SAlex Zinenko LogicalResult 265240a8bd63SAlex Zinenko checkDoubleConsume(Value value, 265340a8bd63SAlex Zinenko function_ref<InFlightDiagnostic()> reportError) { 265440a8bd63SAlex Zinenko OpOperand *potentialConsumer = nullptr; 265540a8bd63SAlex Zinenko for (OpOperand &use : value.getUses()) { 265640a8bd63SAlex Zinenko if (!isValueUsePotentialConsumer(use)) 265740a8bd63SAlex Zinenko continue; 265840a8bd63SAlex Zinenko 265940a8bd63SAlex Zinenko if (!potentialConsumer) { 266040a8bd63SAlex Zinenko potentialConsumer = &use; 266140a8bd63SAlex Zinenko continue; 266240a8bd63SAlex Zinenko } 266340a8bd63SAlex Zinenko 266440a8bd63SAlex Zinenko InFlightDiagnostic diag = reportError() 266540a8bd63SAlex Zinenko << " has more than one potential consumer"; 266640a8bd63SAlex Zinenko diag.attachNote(potentialConsumer->getOwner()->getLoc()) 266740a8bd63SAlex Zinenko << "used here as operand #" << potentialConsumer->getOperandNumber(); 266840a8bd63SAlex Zinenko diag.attachNote(use.getOwner()->getLoc()) 266940a8bd63SAlex Zinenko << "used here as operand #" << use.getOperandNumber(); 267040a8bd63SAlex Zinenko return diag; 267140a8bd63SAlex Zinenko } 267240a8bd63SAlex Zinenko 267340a8bd63SAlex Zinenko return success(); 267440a8bd63SAlex Zinenko } 267540a8bd63SAlex Zinenko 26760eb403adSAlex Zinenko LogicalResult transform::SequenceOp::verify() { 2677b9e40cdeSAlex Zinenko assert(getBodyBlock()->getNumArguments() >= 1 && 2678b9e40cdeSAlex Zinenko "the number of arguments must have been verified to be more than 1 by " 2679df969f66SAlex Zinenko "PossibleTopLevelTransformOpTrait"); 2680df969f66SAlex Zinenko 2681b9e40cdeSAlex Zinenko if (!getRoot() && !getExtraBindings().empty()) { 2682b9e40cdeSAlex Zinenko return emitOpError() 2683b9e40cdeSAlex Zinenko << "does not expect extra operands when used as top-level"; 2684df969f66SAlex Zinenko } 2685df969f66SAlex Zinenko 2686b9e40cdeSAlex Zinenko // Check if a block argument has more than one consuming use. 2687b9e40cdeSAlex Zinenko for (BlockArgument arg : getBodyBlock()->getArguments()) { 2688b9e40cdeSAlex Zinenko if (failed(checkDoubleConsume(arg, [this, arg]() { 2689b9e40cdeSAlex Zinenko return (emitOpError() << "block argument #" << arg.getArgNumber()); 2690b9e40cdeSAlex Zinenko }))) { 269140a8bd63SAlex Zinenko return failure(); 269240a8bd63SAlex Zinenko } 2693b9e40cdeSAlex Zinenko } 269440a8bd63SAlex Zinenko 269540a8bd63SAlex Zinenko // Check properties of the nested operations they cannot check themselves. 26960eb403adSAlex Zinenko for (Operation &child : *getBodyBlock()) { 26970eb403adSAlex Zinenko if (!isa<TransformOpInterface>(child) && 26980eb403adSAlex Zinenko &child != &getBodyBlock()->back()) { 26990eb403adSAlex Zinenko InFlightDiagnostic diag = 27000eb403adSAlex Zinenko emitOpError() 27010eb403adSAlex Zinenko << "expected children ops to implement TransformOpInterface"; 27020eb403adSAlex Zinenko diag.attachNote(child.getLoc()) << "op without interface"; 27030eb403adSAlex Zinenko return diag; 27040eb403adSAlex Zinenko } 27050eb403adSAlex Zinenko 27060eb403adSAlex Zinenko for (OpResult result : child.getResults()) { 270740a8bd63SAlex Zinenko auto report = [&]() { 270840a8bd63SAlex Zinenko return (child.emitError() << "result #" << result.getResultNumber()); 270940a8bd63SAlex Zinenko }; 271040a8bd63SAlex Zinenko if (failed(checkDoubleConsume(result, report))) 271140a8bd63SAlex Zinenko return failure(); 27120eb403adSAlex Zinenko } 27130eb403adSAlex Zinenko } 27140eb403adSAlex Zinenko 271502981c96Svic if (!getBodyBlock()->mightHaveTerminator()) 2716a2a1dbb5SOleksandr "Alex" Zinenko return emitOpError() << "expects to have a terminator in the body"; 2717a2a1dbb5SOleksandr "Alex" Zinenko 27180eb403adSAlex Zinenko if (getBodyBlock()->getTerminator()->getOperandTypes() != 27190eb403adSAlex Zinenko getOperation()->getResultTypes()) { 27200eb403adSAlex Zinenko InFlightDiagnostic diag = emitOpError() 27210eb403adSAlex Zinenko << "expects the types of the terminator operands " 27220eb403adSAlex Zinenko "to match the types of the result"; 27230eb403adSAlex Zinenko diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 27240eb403adSAlex Zinenko return diag; 27250eb403adSAlex Zinenko } 27260eb403adSAlex Zinenko return success(); 27270eb403adSAlex Zinenko } 272830f22429SAlex Zinenko 27290242b962SAlex Zinenko void transform::SequenceOp::getEffects( 27300242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 273194d608d4SAlex Zinenko getPotentialTopLevelEffects(effects); 273240a8bd63SAlex Zinenko } 273340a8bd63SAlex Zinenko 27344dd744acSMarkus Böck OperandRange 27354dd744acSMarkus Böck transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { 27364dd744acSMarkus Böck assert(point == getBody() && "unexpected region index"); 2737b9e40cdeSAlex Zinenko if (getOperation()->getNumOperands() > 0) 273873c3dff1SAlex Zinenko return getOperation()->getOperands(); 273973c3dff1SAlex Zinenko return OperandRange(getOperation()->operand_end(), 274073c3dff1SAlex Zinenko getOperation()->operand_end()); 274173c3dff1SAlex Zinenko } 274273c3dff1SAlex Zinenko 274373c3dff1SAlex Zinenko void transform::SequenceOp::getSuccessorRegions( 27444dd744acSMarkus Böck RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 27454dd744acSMarkus Böck if (point.isParent()) { 274673c3dff1SAlex Zinenko Region *bodyRegion = &getBody(); 2747138df298SMarkus Böck regions.emplace_back(bodyRegion, getNumOperands() != 0 274873c3dff1SAlex Zinenko ? bodyRegion->getArguments() 274973c3dff1SAlex Zinenko : Block::BlockArgListType()); 275073c3dff1SAlex Zinenko return; 275173c3dff1SAlex Zinenko } 275273c3dff1SAlex Zinenko 27534dd744acSMarkus Böck assert(point == getBody() && "unexpected region index"); 275473c3dff1SAlex Zinenko regions.emplace_back(getOperation()->getResults()); 275573c3dff1SAlex Zinenko } 275673c3dff1SAlex Zinenko 275773c3dff1SAlex Zinenko void transform::SequenceOp::getRegionInvocationBounds( 275873c3dff1SAlex Zinenko ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 275973c3dff1SAlex Zinenko (void)operands; 276073c3dff1SAlex Zinenko bounds.emplace_back(1, 1); 276173c3dff1SAlex Zinenko } 276273c3dff1SAlex Zinenko 276300c95b19SMatthias Springer void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 276400c95b19SMatthias Springer TypeRange resultTypes, 276500c95b19SMatthias Springer FailurePropagationMode failurePropagationMode, 276600c95b19SMatthias Springer Value root, 276700c95b19SMatthias Springer SequenceBodyBuilderFn bodyBuilder) { 2768b9e40cdeSAlex Zinenko build(builder, state, resultTypes, failurePropagationMode, root, 276901b9d355SAdrian Kuegel /*extra_bindings=*/ValueRange()); 2770df969f66SAlex Zinenko Type bbArgType = root.getType(); 2771b9e40cdeSAlex Zinenko buildSequenceBody(builder, state, bbArgType, 2772b9e40cdeSAlex Zinenko /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2773b9e40cdeSAlex Zinenko } 277400c95b19SMatthias Springer 2775b9e40cdeSAlex Zinenko void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2776b9e40cdeSAlex Zinenko TypeRange resultTypes, 2777b9e40cdeSAlex Zinenko FailurePropagationMode failurePropagationMode, 2778b9e40cdeSAlex Zinenko Value root, ValueRange extraBindings, 2779b9e40cdeSAlex Zinenko SequenceBodyBuilderArgsFn bodyBuilder) { 2780b9e40cdeSAlex Zinenko build(builder, state, resultTypes, failurePropagationMode, root, 2781b9e40cdeSAlex Zinenko extraBindings); 2782b9e40cdeSAlex Zinenko buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), 2783b9e40cdeSAlex Zinenko bodyBuilder); 278400c95b19SMatthias Springer } 278500c95b19SMatthias Springer 278600c95b19SMatthias Springer void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 278700c95b19SMatthias Springer TypeRange resultTypes, 278800c95b19SMatthias Springer FailurePropagationMode failurePropagationMode, 278900c95b19SMatthias Springer Type bbArgType, 279000c95b19SMatthias Springer SequenceBodyBuilderFn bodyBuilder) { 2791b9e40cdeSAlex Zinenko build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 279201b9d355SAdrian Kuegel /*extra_bindings=*/ValueRange()); 2793b9e40cdeSAlex Zinenko buildSequenceBody(builder, state, bbArgType, 2794b9e40cdeSAlex Zinenko /*extraBindingTypes=*/TypeRange(), bodyBuilder); 2795b9e40cdeSAlex Zinenko } 279600c95b19SMatthias Springer 2797b9e40cdeSAlex Zinenko void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, 2798b9e40cdeSAlex Zinenko TypeRange resultTypes, 2799b9e40cdeSAlex Zinenko FailurePropagationMode failurePropagationMode, 2800b9e40cdeSAlex Zinenko Type bbArgType, TypeRange extraBindingTypes, 2801b9e40cdeSAlex Zinenko SequenceBodyBuilderArgsFn bodyBuilder) { 2802b9e40cdeSAlex Zinenko build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), 280301b9d355SAdrian Kuegel /*extra_bindings=*/ValueRange()); 2804b9e40cdeSAlex Zinenko buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); 280500c95b19SMatthias Springer } 280600c95b19SMatthias Springer 280730f22429SAlex Zinenko //===----------------------------------------------------------------------===// 28084b428364SMatthias Springer // PrintOp 28094b428364SMatthias Springer //===----------------------------------------------------------------------===// 28104b428364SMatthias Springer 2811c8fab80dSNicolas Vasilache void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2812c8fab80dSNicolas Vasilache StringRef name) { 2813214ce4daSJinyun (Joey) Ye if (!name.empty()) 2814214ce4daSJinyun (Joey) Ye result.getOrAddProperties<Properties>().name = builder.getStringAttr(name); 2815c8fab80dSNicolas Vasilache } 2816c8fab80dSNicolas Vasilache 2817c8fab80dSNicolas Vasilache void transform::PrintOp::build(OpBuilder &builder, OperationState &result, 2818c8fab80dSNicolas Vasilache Value target, StringRef name) { 2819c8fab80dSNicolas Vasilache result.addOperands({target}); 2820c8fab80dSNicolas Vasilache build(builder, result, name); 2821c8fab80dSNicolas Vasilache } 2822c8fab80dSNicolas Vasilache 28234b428364SMatthias Springer DiagnosedSilenceableFailure 2824c63d2b2cSMatthias Springer transform::PrintOp::apply(transform::TransformRewriter &rewriter, 2825c63d2b2cSMatthias Springer transform::TransformResults &results, 28264b428364SMatthias Springer transform::TransformState &state) { 282706ca5c81SNicolas Vasilache llvm::outs() << "[[[ IR printer: "; 28284b428364SMatthias Springer if (getName().has_value()) 282906ca5c81SNicolas Vasilache llvm::outs() << *getName() << " "; 28304b428364SMatthias Springer 2831a8cfa7cbSJakub Kuderski OpPrintingFlags printFlags; 2832a8cfa7cbSJakub Kuderski if (getAssumeVerified().value_or(false)) 2833a8cfa7cbSJakub Kuderski printFlags.assumeVerified(); 2834a8cfa7cbSJakub Kuderski if (getUseLocalScope().value_or(false)) 2835a8cfa7cbSJakub Kuderski printFlags.useLocalScope(); 2836a8cfa7cbSJakub Kuderski if (getSkipRegions().value_or(false)) 2837a8cfa7cbSJakub Kuderski printFlags.skipRegions(); 2838a8cfa7cbSJakub Kuderski 28394b428364SMatthias Springer if (!getTarget()) { 2840a8cfa7cbSJakub Kuderski llvm::outs() << "top-level ]]]\n"; 2841a8cfa7cbSJakub Kuderski state.getTopLevel()->print(llvm::outs(), printFlags); 2842a8cfa7cbSJakub Kuderski llvm::outs() << "\n"; 2843*f6bfbc87SOleksandr "Alex" Zinenko llvm::outs().flush(); 28444b428364SMatthias Springer return DiagnosedSilenceableFailure::success(); 28454b428364SMatthias Springer } 28464b428364SMatthias Springer 284706ca5c81SNicolas Vasilache llvm::outs() << "]]]\n"; 2848a8cfa7cbSJakub Kuderski for (Operation *target : state.getPayloadOps(getTarget())) { 2849a8cfa7cbSJakub Kuderski target->print(llvm::outs(), printFlags); 2850a8cfa7cbSJakub Kuderski llvm::outs() << "\n"; 2851a8cfa7cbSJakub Kuderski } 28524b428364SMatthias Springer 2853*f6bfbc87SOleksandr "Alex" Zinenko llvm::outs().flush(); 28544b428364SMatthias Springer return DiagnosedSilenceableFailure::success(); 28554b428364SMatthias Springer } 28564b428364SMatthias Springer 28574b428364SMatthias Springer void transform::PrintOp::getEffects( 28584b428364SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2859fa1b807bSOleksandr "Alex" Zinenko // We don't really care about mutability here, but `getTarget` now 2860fa1b807bSOleksandr "Alex" Zinenko // unconditionally casts to a specific type before verification could run 2861fa1b807bSOleksandr "Alex" Zinenko // here. 2862fa1b807bSOleksandr "Alex" Zinenko if (!getTargetMutable().empty()) 28632c1ae801Sdonald chen onlyReadsHandle(getTargetMutable()[0], effects); 28644b428364SMatthias Springer onlyReadsPayload(effects); 28654b428364SMatthias Springer 28664b428364SMatthias Springer // There is no resource for stderr file descriptor, so just declare print 28674b428364SMatthias Springer // writes into the default resource. 28684b428364SMatthias Springer effects.emplace_back(MemoryEffects::Write::get()); 28694b428364SMatthias Springer } 28700242b962SAlex Zinenko 28710242b962SAlex Zinenko //===----------------------------------------------------------------------===// 28727dfcd4b7SMatthias Springer // VerifyOp 28737dfcd4b7SMatthias Springer //===----------------------------------------------------------------------===// 28747dfcd4b7SMatthias Springer 28757dfcd4b7SMatthias Springer DiagnosedSilenceableFailure 28767dfcd4b7SMatthias Springer transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, 28777dfcd4b7SMatthias Springer Operation *target, 28787dfcd4b7SMatthias Springer transform::ApplyToEachResultList &results, 28797dfcd4b7SMatthias Springer transform::TransformState &state) { 28807dfcd4b7SMatthias Springer if (failed(::mlir::verify(target))) { 28817dfcd4b7SMatthias Springer DiagnosedDefiniteFailure diag = emitDefiniteFailure() 28827dfcd4b7SMatthias Springer << "failed to verify payload op"; 28837dfcd4b7SMatthias Springer diag.attachNote(target->getLoc()) << "payload op"; 28847dfcd4b7SMatthias Springer return diag; 28857dfcd4b7SMatthias Springer } 28867dfcd4b7SMatthias Springer return DiagnosedSilenceableFailure::success(); 28877dfcd4b7SMatthias Springer } 28887dfcd4b7SMatthias Springer 28897dfcd4b7SMatthias Springer void transform::VerifyOp::getEffects( 28907dfcd4b7SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 28912c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 28927dfcd4b7SMatthias Springer } 28937dfcd4b7SMatthias Springer 28947dfcd4b7SMatthias Springer //===----------------------------------------------------------------------===// 28950242b962SAlex Zinenko // YieldOp 28960242b962SAlex Zinenko //===----------------------------------------------------------------------===// 28970242b962SAlex Zinenko 28980242b962SAlex Zinenko void transform::YieldOp::getEffects( 28990242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 29002c1ae801Sdonald chen onlyReadsHandle(getOperandsMutable(), effects); 29010242b962SAlex Zinenko } 2902