//===- TransformOps.cpp - Transform dialect operations --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) using namespace mlir; static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, SmallVectorImpl &extraBindings, SmallVectorImpl &extraBindingTypes); static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes); static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions); static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions); /// Helper function to check if the given transform op is contained in (or /// equal to) the given payload target op. In that case, an error is returned. /// Transforming transform IR that is currently executing is generally unsafe. static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload) { Operation *transformAncestor = transform.getOperation(); while (transformAncestor) { if (transformAncestor == payload) { DiagnosedDefiniteFailure diag = transform.emitDefiniteFailure() << "cannot apply transform to itself (or one of its ancestors)"; diag.attachNote(payload->getLoc()) << "target payload op"; return diag; } transformAncestor = transformAncestor->getParentOp(); } return DiagnosedSilenceableFailure::success(); } #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// OperandRange transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { if (!point.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return 
OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( getAlternatives(), point.isParent() ? 0 : point.getRegionOrNull()->getRegionNumber() + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) regions.emplace_back(getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; // The region corresponding to the first alternative is always executed, the // remaining may or may not be executed. bounds.reserve(getNumRegions()); bounds.emplace_back(1, 1); bounds.resize(getNumRegions(), InvocationBounds(0, 1)); } static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { for (const auto &res : block->getParentOp()->getOpResults()) results.set(res, {}); } DiagnosedSilenceableFailure transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) llvm::append_range(originals, state.getPayloadOps(scopeHandle)); else originals.push_back(state.getTopLevel()); for (Operation *original : originals) { if (original->isAncestor(getOperation())) { auto diag = emitDefiniteFailure() << "scope must not contain the transforms being applied"; diag.attachNote(original->getLoc()) << "scope"; return diag; } if (!original->hasTrait()) { auto diag = emitDefiniteFailure() << "only isolated-from-above ops can be alternative scopes"; diag.attachNote(original->getLoc()) << "scope"; return diag; } } for (Region ® : getAlternatives()) { // Clone the scope operations and make the transforms in this alternative // region apply to them by virtue of 
mapping the block argument (the only // visible handle) to the cloned scope operations. This effectively prevents // the transformation from accessing any IR outside the scope. auto scope = state.make_region_scope(reg); auto clones = llvm::to_vector( llvm::map_range(originals, [](Operation *op) { return op->clone(); })); auto deleteClones = llvm::make_scope_exit([&] { for (Operation *clone : clones) clone->erase(); }); if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) return DiagnosedSilenceableFailure::definiteFailure(); bool failed = false; for (Operation &transform : reg.front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isSilenceableFailure()) { LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() << "\n"); failed = true; break; } if (::mlir::failed(result.silence())) return DiagnosedSilenceableFailure::definiteFailure(); } // If all operations in the given alternative succeeded, no need to consider // the rest. Replace the original scoping operation with the clone on which // the transformations were performed. if (!failed) { // We will be using the clones, so cancel their scheduled deletion. 
deleteClones.release(); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); for (const auto &kvp : llvm::zip(originals, clones)) { Operation *original = std::get<0>(kvp); Operation *clone = std::get<1>(kvp); original->getBlock()->getOperations().insert(original->getIterator(), clone); rewriter.replaceOp(original, clone->getResults()); } detail::forwardTerminatorOperands(®.front(), state, results); return DiagnosedSilenceableFailure::success(); } } return emitSilenceableError() << "all alternatives failed"; } void transform::AlternativesOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getOperation()->getOpOperands(), effects); producesHandle(getOperation()->getOpResults(), effects); for (Region *region : getRegions()) { if (!region->empty()) producesHandle(region->front().getArguments(), effects); } modifiesPayload(effects); } LogicalResult transform::AlternativesOp::verify() { for (Region &alternative : getAlternatives()) { Block &block = alternative.front(); Operation *terminator = block.getTerminator(); if (terminator->getOperands().getTypes() != getResults().getTypes()) { InFlightDiagnostic diag = emitOpError() << "expects terminator operands to have the " "same type as results of the operation"; diag.attachNote(terminator->getLoc()) << "terminator"; return diag; } } return success(); } //===----------------------------------------------------------------------===// // AnnotateOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); Attribute attr = UnitAttr::get(getContext()); if (auto paramH = getParam()) { ArrayRef params = state.getParams(paramH); if (params.size() != 1) { if (targets.size() != params.size()) { return emitSilenceableError() << 
"parameter and target have different payload lengths (" << params.size() << " vs " << targets.size() << ")"; } for (auto &&[target, attr] : llvm::zip_equal(targets, params)) target->setAttr(getName(), attr); return DiagnosedSilenceableFailure::success(); } attr = params[0]; } for (auto *target : targets) target->setAttr(getName(), attr); return DiagnosedSilenceableFailure::success(); } void transform::AnnotateOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTargetMutable(), effects); onlyReadsHandle(getParamMutable(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // ApplyCommonSubexpressionEliminationOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyCommonSubexpressionEliminationOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. 
DiagnosedSilenceableFailure payloadCheck = ensurePayloadIsSeparateFromTransform(*this, target); if (!payloadCheck.succeeded()) return payloadCheck; DominanceInfo domInfo; mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); return DiagnosedSilenceableFailure::success(); } void transform::ApplyCommonSubexpressionEliminationOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } //===----------------------------------------------------------------------===// // ApplyDeadCodeEliminationOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. DiagnosedSilenceableFailure payloadCheck = ensurePayloadIsSeparateFromTransform(*this, target); if (!payloadCheck.succeeded()) return payloadCheck; // Maintain a worklist of potentially dead ops. SetVector worklist; // Helper function that adds all defining ops of used values (operands and // operands of nested ops). auto addDefiningOpsToWorklist = [&](Operation *op) { op->walk([&](Operation *op) { for (Value v : op->getOperands()) if (Operation *defOp = v.getDefiningOp()) if (target->isProperAncestor(defOp)) worklist.insert(defOp); }); }; // Helper function that erases an op. auto eraseOp = [&](Operation *op) { // Remove op and nested ops from the worklist. op->walk([&](Operation *op) { const auto *it = llvm::find(worklist, op); if (it != worklist.end()) worklist.erase(it); }); rewriter.eraseOp(op); }; // Initial walk over the IR. 
target->walk([&](Operation *op) { if (op != target && isOpTriviallyDead(op)) { addDefiningOpsToWorklist(op); eraseOp(op); } }); // Erase all ops that have become dead. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); if (!isOpTriviallyDead(op)) continue; addDefiningOpsToWorklist(op); eraseOp(op); } return DiagnosedSilenceableFailure::success(); } void transform::ApplyDeadCodeEliminationOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } //===----------------------------------------------------------------------===// // ApplyPatternsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. Even // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver // performs many additional simplifications such as dead code elimination. DiagnosedSilenceableFailure payloadCheck = ensurePayloadIsSeparateFromTransform(*this, target); if (!payloadCheck.succeeded()) return payloadCheck; // Gather all specified patterns. MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { cast(&op) .populatePatternsWithState(patterns, state); } } // Configure the GreedyPatternRewriteDriver. GreedyRewriteConfig config; config.listener = static_cast(rewriter.getListener()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); config.maxIterations = getMaxIterations() == static_cast(-1) ? GreedyRewriteConfig::kNoLimit : getMaxIterations(); config.maxNumRewrites = getMaxNumRewrites() == static_cast(-1) ? 
GreedyRewriteConfig::kNoLimit : getMaxNumRewrites(); // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE // was requested, apply the greedy pattern rewrite only once. (The greedy // pattern rewrite driver already iterates to a fixpoint internally.) bool cseChanged = false; // One or two iterations should be sufficient. Stop iterating after a certain // threshold to make debugging easier. static const int64_t kNumMaxIterations = 50; int64_t iteration = 0; do { LogicalResult result = failure(); if (target->hasTrait()) { // Op is isolated from above. Apply patterns and also perform region // simplification. result = applyPatternsGreedily(target, frozenPatterns, config); } else { // Manually gather list of ops because the other // GreedyPatternRewriteDriver overloads only accepts ops that are isolated // from above. This way, patterns can be applied to ops that are not // isolated from above. Regions are not being simplified. Furthermore, // only a single greedy rewrite iteration is performed. SmallVector ops; target->walk([&](Operation *nestedOp) { if (target != nestedOp) ops.push_back(nestedOp); }); result = applyOpPatternsGreedily(ops, frozenPatterns, config); } // A failure typically indicates that the pattern application did not // converge. 
if (failed(result)) { return emitSilenceableFailure(target) << "greedy pattern application failed"; } if (getApplyCse()) { DominanceInfo domInfo; mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, &cseChanged); } } while (cseChanged && ++iteration < kNumMaxIterations); if (iteration == kNumMaxIterations) return emitDefiniteFailure() << "fixpoint iteration did not converge"; return DiagnosedSilenceableFailure::success(); } LogicalResult transform::ApplyPatternsOp::verify() { if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { if (!isa(&op)) { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement " "PatternDescriptorOpInterface"; diag.attachNote(op.getLoc()) << "op without interface"; return diag; } } } return success(); } void transform::ApplyPatternsOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } void transform::ApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, function_ref bodyBuilder) { result.addOperands(target); OpBuilder::InsertionGuard g(builder); Region *region = result.addRegion(); builder.createBlock(region); if (bodyBuilder) bodyBuilder(builder, result.location); } //===----------------------------------------------------------------------===// // ApplyCanonicalizationPatternsOp //===----------------------------------------------------------------------===// void transform::ApplyCanonicalizationPatternsOp::populatePatterns( RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); for (Dialect *dialect : ctx->getLoadedDialects()) dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx->getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, ctx); } //===----------------------------------------------------------------------===// // ApplyConversionPatternsOp 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { MLIRContext *ctx = getContext(); // Instantiate the default type converter if a type converter builder is // specified. std::unique_ptr defaultTypeConverter; transform::TypeConverterBuilderOpInterface typeConverterBuilder = getDefaultTypeConverter(); if (typeConverterBuilder) defaultTypeConverter = typeConverterBuilder.getTypeConverter(); // Configure conversion target. ConversionTarget conversionTarget(*getContext()); if (getLegalOps()) for (Attribute attr : cast(*getLegalOps())) conversionTarget.addLegalOp( OperationName(cast(attr).getValue(), ctx)); if (getIllegalOps()) for (Attribute attr : cast(*getIllegalOps())) conversionTarget.addIllegalOp( OperationName(cast(attr).getValue(), ctx)); if (getLegalDialects()) for (Attribute attr : cast(*getLegalDialects())) conversionTarget.addLegalDialect(cast(attr).getValue()); if (getIllegalDialects()) for (Attribute attr : cast(*getIllegalDialects())) conversionTarget.addIllegalDialect(cast(attr).getValue()); // Gather all specified patterns. RewritePatternSet patterns(ctx); // Need to keep the converters alive until after pattern application because // the patterns take a reference to an object that would otherwise get out of // scope. SmallVector> keepAliveConverters; if (!getPatterns().empty()) { for (Operation &op : getPatterns().front()) { auto descriptor = cast(&op); // Check if this pattern set specifies a type converter. std::unique_ptr typeConverter = descriptor.getTypeConverter(); TypeConverter *converter = nullptr; if (typeConverter) { keepAliveConverters.emplace_back(std::move(typeConverter)); converter = keepAliveConverters.back().get(); } else { // No type converter specified: Use the default type converter. 
if (!defaultTypeConverter) { auto diag = emitDefiniteFailure() << "pattern descriptor does not specify type " "converter and apply_conversion_patterns op has " "no default type converter"; diag.attachNote(op.getLoc()) << "pattern descriptor op"; return diag; } converter = defaultTypeConverter.get(); } // Add descriptor-specific updates to the conversion target, which may // depend on the final type converter. In structural converters, the // legality of types dictates the dynamic legality of an operation. descriptor.populateConversionTargetRules(*converter, conversionTarget); descriptor.populatePatterns(*converter, patterns); } } // Attach a tracking listener if handles should be preserved. We configure the // listener to allow op replacements with different names, as conversion // patterns typically replace ops with replacement ops that have a different // name. TrackingListenerConfig trackingConfig; trackingConfig.requireMatchingReplacementOpName = false; ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); ConversionConfig conversionConfig; if (getPreserveHandles()) conversionConfig.listener = &trackingListener; FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (Operation *target : state.getPayloadOps(getTarget())) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. DiagnosedSilenceableFailure payloadCheck = ensurePayloadIsSeparateFromTransform(*this, target); if (!payloadCheck.succeeded()) return payloadCheck; LogicalResult status = failure(); if (getPartialConversion()) { status = applyPartialConversion(target, conversionTarget, frozenPatterns, conversionConfig); } else { status = applyFullConversion(target, conversionTarget, frozenPatterns, conversionConfig); } // Check dialect conversion state. 
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); if (failed(status)) { diag = emitSilenceableError() << "dialect conversion failed"; diag.attachNote(target->getLoc()) << "target op"; } // Check tracking listener error state. DiagnosedSilenceableFailure trackingFailure = trackingListener.checkAndResetError(); if (!trackingFailure.succeeded()) { if (diag.succeeded()) { // Tracking failure is the only failure. return trackingFailure; } else { diag.attachNote() << "tracking listener also failed: " << trackingFailure.getMessage(); (void)trackingFailure.silence(); } } if (!diag.succeeded()) return diag; } return DiagnosedSilenceableFailure::success(); } LogicalResult transform::ApplyConversionPatternsOp::verify() { if (getNumRegions() != 1 && getNumRegions() != 2) return emitOpError() << "expected 1 or 2 regions"; if (!getPatterns().empty()) { for (Operation &op : getPatterns().front()) { if (!isa(&op)) { InFlightDiagnostic diag = emitOpError() << "expected pattern children ops to implement " "ConversionPatternDescriptorOpInterface"; diag.attachNote(op.getLoc()) << "op without interface"; return diag; } } } if (getNumRegions() == 2) { Region &typeConverterRegion = getRegion(1); if (!llvm::hasSingleElement(typeConverterRegion.front())) return emitOpError() << "expected exactly one op in default type converter region"; Operation *maybeTypeConverter = &typeConverterRegion.front().front(); auto typeConverterOp = dyn_cast( maybeTypeConverter); if (!typeConverterOp) { InFlightDiagnostic diag = emitOpError() << "expected default converter child op to " "implement TypeConverterBuilderOpInterface"; diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface"; return diag; } // Check default type converter type. 
if (!getPatterns().empty()) { for (Operation &op : getPatterns().front()) { auto descriptor = cast(&op); if (failed(descriptor.verifyTypeConverter(typeConverterOp))) return failure(); } } } return success(); } void transform::ApplyConversionPatternsOp::getEffects( SmallVectorImpl &effects) { if (!getPreserveHandles()) { transform::consumesHandle(getTargetMutable(), effects); } else { transform::onlyReadsHandle(getTargetMutable(), effects); } transform::modifiesPayload(effects); } void transform::ApplyConversionPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, function_ref patternsBodyBuilder, function_ref typeConverterBodyBuilder) { result.addOperands(target); { OpBuilder::InsertionGuard g(builder); Region *region1 = result.addRegion(); builder.createBlock(region1); if (patternsBodyBuilder) patternsBodyBuilder(builder, result.location); } { OpBuilder::InsertionGuard g(builder); Region *region2 = result.addRegion(); builder.createBlock(region2); if (typeConverterBodyBuilder) typeConverterBodyBuilder(builder, result.location); } } //===----------------------------------------------------------------------===// // ApplyToLLVMConversionPatternsOp //===----------------------------------------------------------------------===// void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); assert(dialect && "expected that dialect is loaded"); auto *iface = cast(dialect); // ConversionTarget is currently ignored because the enclosing // apply_conversion_patterns op sets up its own ConversionTarget. 
ConversionTarget target(*getContext()); iface->populateConvertToLLVMConversionPatterns( target, static_cast(typeConverter), patterns); } LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( transform::TypeConverterBuilderOpInterface builder) { if (builder.getTypeConverterType() != "LLVMTypeConverter") return emitOpError("expected LLVMTypeConverter"); return success(); } LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); if (!dialect) return emitOpError("unknown dialect or dialect not loaded: ") << getDialectName(); auto *iface = dyn_cast(dialect); if (!iface) return emitOpError( "dialect does not implement ConvertToLLVMPatternInterface or " "extension was not loaded: ") << getDialectName(); return success(); } //===----------------------------------------------------------------------===// // ApplyLoopInvariantCodeMotionOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyLoopInvariantCodeMotionOp::applyToOne( transform::TransformRewriter &rewriter, LoopLikeOpInterface target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Currently, LICM does not remove operations, so we don't need tracking. // If this ever changes, add a LICM entry point that takes a rewriter. 
moveLoopInvariantCode(target); return DiagnosedSilenceableFailure::success(); } void transform::ApplyLoopInvariantCodeMotionOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } //===----------------------------------------------------------------------===// // ApplyRegisteredPassOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. Even // more so when applying passes because they may perform a wide range of IR // modifications. DiagnosedSilenceableFailure payloadCheck = ensurePayloadIsSeparateFromTransform(*this, target); if (!payloadCheck.succeeded()) return payloadCheck; // Get pass or pass pipeline from registry. const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); if (!info) info = PassInfo::lookup(getPassName()); if (!info) return emitDefiniteFailure() << "unknown pass or pass pipeline: " << getPassName(); // Create pass manager and run the pass or pass pipeline. 
PassManager pm(getContext()); if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { emitError(msg); return failure(); }))) { return emitDefiniteFailure() << "failed to add pass or pass pipeline to pipeline: " << getPassName(); } if (failed(pm.run(target))) { auto diag = emitSilenceableError() << "pass pipeline failed"; diag.attachNote(target->getLoc()) << "target op"; return diag; } results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } void transform::CastOp::getEffects( SmallVectorImpl &effects) { onlyReadsPayload(effects); onlyReadsHandle(getInputMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); } bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { assert(inputs.size() == 1 && "expected one input"); assert(outputs.size() == 1 && "expected one output"); return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, llvm::IsaPred); } //===----------------------------------------------------------------------===// // CollectMatchingOp //===----------------------------------------------------------------------===// /// Applies matcher operations from the given `block` using /// `blockArgumentMapping` to initialize block arguments. Updates `state` /// accordingly. If any of the matcher produces a silenceable failure, discards /// it (printing the content to the debug output stream) and returns failure. If /// any of the matchers produces a definite failure, reports it and returns /// failure. 
If all matchers in the block succeed, populates `mappings` with the /// payload entities associated with the block terminator operands. Note that /// `mappings` will be cleared before that. static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl> &mappings) { assert(block.getParent() && "cannot match using a detached block"); auto matchScope = state.make_region_scope(*block.getParent()); if (failed( state.mapBlockArguments(block.getArguments(), blockArgumentMapping))) return DiagnosedSilenceableFailure::definiteFailure(); for (Operation &match : block.without_terminator()) { if (!isa(match)) { return emitDefiniteFailure(match.getLoc()) << "expected operations in the match part to " "implement MatchOpInterface"; } DiagnosedSilenceableFailure diag = state.applyTransform(cast(match)); if (diag.succeeded()) continue; return diag; } // Remember the values mapped to the terminator operands so we can // forward them to the action. ValueRange yieldedValues = block.getTerminator()->getOperands(); // Our contract with the caller is that the mappings will contain only the // newly mapped values, clear the rest. mappings.clear(); transform::detail::prepareValueMappings(mappings, yieldedValues, state); return DiagnosedSilenceableFailure::success(); } /// Returns `true` if both types implement one of the interfaces provided as /// template parameters. template static bool implementSameInterface(Type t1, Type t2) { return ((isa(t1) && isa(t2)) || ... || false); } /// Returns `true` if both types implement one of the transform dialect /// interfaces. 
static bool implementSameTransformInterface(Type t1, Type t2) { return implementSameInterface( t1, t2); } //===----------------------------------------------------------------------===// // CollectMatchingOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto matcher = SymbolTable::lookupNearestSymbolFrom( getOperation(), getMatcher()); if (matcher.isExternal()) { return emitDefiniteFailure() << "unresolved external symbol " << getMatcher(); } SmallVector, 2> rawResults; rawResults.resize(getOperation()->getNumResults()); std::optional maybeFailure; for (Operation *root : state.getPayloadOps(getRoot())) { WalkResult walkResult = root->walk([&](Operation *op) { DEBUG_MATCHER({ DBGS_MATCHER() << "matching "; op->print(llvm::dbgs(), OpPrintingFlags().assumeVerified().skipRegions()); llvm::dbgs() << " @" << op << "\n"; }); // Try matching. SmallVector> mappings; SmallVector inputMapping({op}); DiagnosedSilenceableFailure diag = matchBlock( matcher.getFunctionBody().front(), ArrayRef>(inputMapping), state, mappings); if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() << " failed: " << diag.getMessage()); return WalkResult::advance(); } // If succeeded, collect results. 
for (auto &&[i, mapping] : llvm::enumerate(mappings)) { if (mapping.size() != 1) { maybeFailure.emplace(emitSilenceableError() << "result #" << i << ", associated with " << mapping.size() << " payload objects, expected 1"); return WalkResult::interrupt(); } rawResults[i].push_back(mapping[0]); } return WalkResult::advance(); }); if (walkResult.wasInterrupted()) return std::move(*maybeFailure); assert(!maybeFailure && "failure set but the walk was not interrupted"); for (auto &&[opResult, rawResult] : llvm::zip_equal(getOperation()->getResults(), rawResults)) { results.setMappedValues(opResult, rawResult); } } return DiagnosedSilenceableFailure::success(); } void transform::CollectMatchingOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getRootMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); onlyReadsPayload(effects); } LogicalResult transform::CollectMatchingOp::verifySymbolUses( SymbolTableCollection &symbolTable) { auto matcherSymbol = dyn_cast_or_null( symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); if (!matcherSymbol || !isa(matcherSymbol.getOperation())) return emitError() << "unresolved matcher symbol " << getMatcher(); ArrayRef argumentTypes = matcherSymbol.getArgumentTypes(); if (argumentTypes.size() != 1 || !isa(argumentTypes[0])) { return emitError() << "expected the matcher to take one operation handle argument"; } if (!matcherSymbol.getArgAttr( 0, transform::TransformDialect::kArgReadOnlyAttrName)) { return emitError() << "expected the matcher argument to be marked readonly"; } ArrayRef resultTypes = matcherSymbol.getResultTypes(); if (resultTypes.size() != getOperation()->getNumResults()) { return emitError() << "expected the matcher to yield as many values as op has results (" << getOperation()->getNumResults() << "), got " << resultTypes.size(); } for (auto &&[i, matcherType, resultType] : llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { if 
(implementSameTransformInterface(matcherType, resultType)) continue; return emitError() << "mismatching type interfaces for matcher result and op result #" << i; } return success(); } //===----------------------------------------------------------------------===// // ForeachMatchOp //===----------------------------------------------------------------------===// // This is fine because nothing is actually consumed by this op. bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; } DiagnosedSilenceableFailure transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector> matchActionPairs; matchActionPairs.reserve(getMatchers().size()); SymbolTableCollection symbolTable; for (auto &&[matcher, action] : llvm::zip_equal(getMatchers(), getActions())) { auto matcherSymbol = symbolTable.lookupNearestSymbolFrom( getOperation(), cast(matcher)); auto actionSymbol = symbolTable.lookupNearestSymbolFrom( getOperation(), cast(action)); assert(matcherSymbol && actionSymbol && "unresolved symbols not caught by the verifier"); if (matcherSymbol.isExternal()) return emitDefiniteFailure() << "unresolved external symbol " << matcher; if (actionSymbol.isExternal()) return emitDefiniteFailure() << "unresolved external symbol " << action; matchActionPairs.emplace_back(matcherSymbol, actionSymbol); } DiagnosedSilenceableFailure overallDiag = DiagnosedSilenceableFailure::success(); SmallVector> matchInputMapping; SmallVector> matchOutputMapping; SmallVector> actionResultMapping; // Explicitly add the mapping for the first block argument (the op being // matched). 
matchInputMapping.emplace_back(); transform::detail::prepareValueMappings(matchInputMapping, getForwardedInputs(), state); SmallVector &firstMatchArgument = matchInputMapping.front(); actionResultMapping.resize(getForwardedOutputs().size()); for (Operation *root : state.getPayloadOps(getRoot())) { WalkResult walkResult = root->walk([&](Operation *op) { // If getRestrictRoot is not present, skip over the root op itself so we // don't invalidate it. if (!getRestrictRoot() && op == root) return WalkResult::advance(); DEBUG_MATCHER({ DBGS_MATCHER() << "matching "; op->print(llvm::dbgs(), OpPrintingFlags().assumeVerified().skipRegions()); llvm::dbgs() << " @" << op << "\n"; }); firstMatchArgument.clear(); firstMatchArgument.push_back(op); // Try all the match/action pairs until the first successful match. for (auto [matcher, action] : matchActionPairs) { DiagnosedSilenceableFailure diag = matchBlock(matcher.getFunctionBody().front(), matchInputMapping, state, matchOutputMapping); if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() << " failed: " << diag.getMessage()); continue; } auto scope = state.make_region_scope(action.getFunctionBody()); if (failed(state.mapBlockArguments( action.getFunctionBody().front().getArguments(), matchOutputMapping))) { return WalkResult::interrupt(); } for (Operation &transform : action.getFunctionBody().front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isDefiniteFailure()) return WalkResult::interrupt(); if (result.isSilenceableFailure()) { if (overallDiag.succeeded()) { overallDiag = emitSilenceableError() << "actions failed"; } overallDiag.attachNote(action->getLoc()) << "failed action: " << result.getMessage(); overallDiag.attachNote(op->getLoc()) << "when applied to this matching payload"; (void)result.silence(); continue; } } if 
(failed(detail::appendValueMappings( MutableArrayRef>(actionResultMapping), action.getFunctionBody().front().getTerminator()->getOperands(), state, getFlattenResults()))) { emitDefiniteFailure() << "action @" << action.getName() << " has results associated with multiple payload entities, " "but flattening was not requested"; return WalkResult::interrupt(); } break; } return WalkResult::advance(); }); if (walkResult.wasInterrupted()) return DiagnosedSilenceableFailure::definiteFailure(); } // The root operation should not have been affected, so we can just reassign // the payload to the result. Note that we need to consume the root handle to // make sure any handles to operations inside, that could have been affected // by actions, are invalidated. results.set(llvm::cast(getUpdated()), state.getPayloadOps(getRoot())); for (auto &&[result, mapping] : llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) { results.setMappedValues(result, mapping); } return overallDiag; } void transform::ForeachMatchOp::getAsmResultNames( OpAsmSetValueNameFn setNameFn) { setNameFn(getUpdated(), "updated_root"); for (Value v : getForwardedOutputs()) { setNameFn(v, "yielded"); } } void transform::ForeachMatchOp::getEffects( SmallVectorImpl &effects) { // Bail if invalid. if (getOperation()->getNumOperands() < 1 || getOperation()->getNumResults() < 1) { return modifiesPayload(effects); } consumesHandle(getRootMutable(), effects); onlyReadsHandle(getForwardedInputsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } /// Parses the comma-separated list of symbol reference pairs of the format /// `@matcher -> @action`. 
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions) { StringAttr matcher; StringAttr action; SmallVector matcherList; SmallVector actionList; do { if (parser.parseSymbolName(matcher) || parser.parseArrow() || parser.parseSymbolName(action)) { return failure(); } matcherList.push_back(SymbolRefAttr::get(matcher)); actionList.push_back(SymbolRefAttr::get(action)); } while (parser.parseOptionalComma().succeeded()); matchers = parser.getBuilder().getArrayAttr(matcherList); actions = parser.getBuilder().getArrayAttr(actionList); return success(); } /// Prints the comma-separated list of symbol reference pairs of the format /// `@matcher -> @action`. static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions) { printer.increaseIndent(); printer.increaseIndent(); for (auto &&[matcher, action, idx] : llvm::zip_equal( matchers, actions, llvm::seq(0, matchers.size()))) { printer.printNewline(); printer << cast(matcher) << " -> " << cast(action); if (idx != matchers.size() - 1) printer << ", "; } printer.decreaseIndent(); printer.decreaseIndent(); } LogicalResult transform::ForeachMatchOp::verify() { if (getMatchers().size() != getActions().size()) return emitOpError() << "expected the same number of matchers and actions"; if (getMatchers().empty()) return emitOpError() << "expected at least one match/action pair"; llvm::SmallPtrSet matcherNames; for (Attribute name : getMatchers()) { if (matcherNames.insert(name).second) continue; emitWarning() << "matcher " << name << " is used more than once, only the first match will apply"; } return success(); } /// Checks that the attributes of the function-like operation have correct /// consumption effect annotations. If `alsoVerifyInternal`, checks for /// annotations being present even if they can be inferred from the body. 
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal = false) { auto transformOp = cast(op.getOperation()); llvm::SmallDenseSet consumedArguments; if (!op.isExternal()) { transform::getConsumedBlockArguments(op.getFunctionBody().front(), consumedArguments); } for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { bool isConsumed = op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != nullptr; bool isReadOnly = op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != nullptr; if (isConsumed && isReadOnly) { return transformOp.emitSilenceableError() << "argument #" << i << " cannot be both readonly and consumed"; } if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { return transformOp.emitSilenceableError() << "must provide consumed/readonly status for arguments of " "external or called ops"; } if (op.isExternal()) continue; if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { return transformOp.emitSilenceableError() << "argument #" << i << " is consumed in the body but is not marked as such"; } if (emitWarnings && !consumedArguments.contains(i) && isConsumed) { // Cannot use op.emitWarning() here as it would attempt to verify the op // before printing, resulting in infinite recursion. emitWarning(op->getLoc()) << "op argument #" << i << " is not consumed in the body but is marked as consumed"; } } return DiagnosedSilenceableFailure::success(); } LogicalResult transform::ForeachMatchOp::verifySymbolUses( SymbolTableCollection &symbolTable) { assert(getMatchers().size() == getActions().size()); auto consumedAttr = StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); for (auto &&[matcher, action] : llvm::zip_equal(getMatchers(), getActions())) { // Presence and typing. 
auto matcherSymbol = dyn_cast_or_null( symbolTable.lookupNearestSymbolFrom(getOperation(), cast(matcher))); auto actionSymbol = dyn_cast_or_null( symbolTable.lookupNearestSymbolFrom(getOperation(), cast(action))); if (!matcherSymbol || !isa(matcherSymbol.getOperation())) return emitError() << "unresolved matcher symbol " << matcher; if (!actionSymbol || !isa(actionSymbol.getOperation())) return emitError() << "unresolved action symbol " << action; if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, /*emitWarnings=*/false, /*alsoVerifyInternal=*/true) .checkAndReport())) { return failure(); } if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, /*emitWarnings=*/false, /*alsoVerifyInternal=*/true) .checkAndReport())) { return failure(); } // Input -> matcher forwarding. TypeRange operandTypes = getOperandTypes(); TypeRange matcherArguments = matcherSymbol.getArgumentTypes(); if (operandTypes.size() != matcherArguments.size()) { InFlightDiagnostic diag = emitError() << "the number of operands (" << operandTypes.size() << ") doesn't match the number of matcher arguments (" << matcherArguments.size() << ") for " << matcher; diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; return diag; } for (auto &&[i, operand, argument] : llvm::enumerate(operandTypes, matcherArguments)) { if (matcherSymbol.getArgAttr(i, consumedAttr)) { InFlightDiagnostic diag = emitOpError() << "does not expect matcher symbol to consume its operand #" << i; diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; return diag; } if (implementSameTransformInterface(operand, argument)) continue; InFlightDiagnostic diag = emitError() << "mismatching type interfaces for operand and matcher argument #" << i << " of matcher " << matcher; diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; return diag; } // Matcher -> action forwarding. 
TypeRange matcherResults = matcherSymbol.getResultTypes(); TypeRange actionArguments = actionSymbol.getArgumentTypes(); if (matcherResults.size() != actionArguments.size()) { return emitError() << "mismatching number of matcher results and " "action arguments between " << matcher << " (" << matcherResults.size() << ") and " << action << " (" << actionArguments.size() << ")"; } for (auto &&[i, matcherType, actionType] : llvm::enumerate(matcherResults, actionArguments)) { if (implementSameTransformInterface(matcherType, actionType)) continue; return emitError() << "mismatching type interfaces for matcher result " "and action argument #" << i << "of matcher " << matcher << " and action " << action; } // Action -> result forwarding. TypeRange actionResults = actionSymbol.getResultTypes(); auto resultTypes = TypeRange(getResultTypes()).drop_front(); if (actionResults.size() != resultTypes.size()) { InFlightDiagnostic diag = emitError() << "the number of action results (" << actionResults.size() << ") for " << action << " doesn't match the number of extra op results (" << resultTypes.size() << ")"; diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; return diag; } for (auto &&[i, resultType, actionType] : llvm::enumerate(resultTypes, actionResults)) { if (implementSameTransformInterface(resultType, actionType)) continue; InFlightDiagnostic diag = emitError() << "mismatching type interfaces for action result #" << i << " of action " << action << " and op result"; diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; return diag; } } return success(); } //===----------------------------------------------------------------------===// // ForeachOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { // We store the payloads before executing the body as ops 
may be removed from // the mapping by the TrackingRewriter while iteration is in progress. SmallVector> payloads; detail::prepareValueMappings(payloads, getTargets(), state); size_t numIterations = payloads.empty() ? 0 : payloads.front().size(); bool withZipShortest = getWithZipShortest(); // In case of `zip_shortest`, set the number of iterations to the // smallest payload in the targets. if (withZipShortest) { numIterations = llvm::min_element(payloads, [&](const SmallVector &A, const SmallVector &B) { return A.size() < B.size(); })->size(); for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++) payloads[argIdx].resize(numIterations); } // As we will be "zipping" over them, check all payloads have the same size. // `zip_shortest` adjusts all payloads to the same size, so skip this check // when true. for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size(); argIdx++) { if (payloads[argIdx].size() != numIterations) { return emitSilenceableError() << "prior targets' payload size (" << numIterations << ") differs from payload size (" << payloads[argIdx].size() << ") of target " << getTargets()[argIdx]; } } // Start iterating, indexing into payloads to obtain the right arguments to // call the body with - each slice of payloads at the same argument index // corresponding to a tuple to use as the body's block arguments. ArrayRef blockArguments = getBody().front().getArguments(); SmallVector> zippedResults(getNumResults(), {}); for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) { auto scope = state.make_region_scope(getBody()); // Set up arguments to the region's block. for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) { MappedValue argument = payloads[argIdx][iterIdx]; // Note that each blockArg's handle gets associated with just a single // element from the corresponding target's payload. if (failed(state.mapBlockArgument(blockArg, {argument}))) return DiagnosedSilenceableFailure::definiteFailure(); } // Execute loop body. 
for (Operation &transform : getBody().front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform( llvm::cast(transform)); if (!result.succeeded()) return result; } // Append yielded payloads to corresponding results from prior iterations. OperandRange yieldOperands = getYieldOp().getOperands(); for (auto &&[result, yieldOperand, resTuple] : llvm::zip_equal(getResults(), yieldOperands, zippedResults)) // NB: each iteration we add any number of ops/vals/params to a result. if (isa(result.getType())) llvm::append_range(resTuple, state.getPayloadOps(yieldOperand)); else if (isa(result.getType())) llvm::append_range(resTuple, state.getPayloadValues(yieldOperand)); else if (isa(result.getType())) llvm::append_range(resTuple, state.getParams(yieldOperand)); else assert(false && "unhandled handle type"); } // Associate the accumulated result payloads to the op's actual results. for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults)) results.setMappedValues(llvm::cast(result), resPayload); return DiagnosedSilenceableFailure::success(); } void transform::ForeachOp::getEffects( SmallVectorImpl &effects) { // NB: this `zip` should be `zip_equal` - while this op's verifier catches // arity errors, this method might get called before/in absence of `verify()`. 
for (auto &&[target, blockArg] : llvm::zip(getTargetsMutable(), getBody().front().getArguments())) { BlockArgument blockArgument = blockArg; if (any_of(getBody().front().without_terminator(), [&](Operation &op) { return isHandleConsumed(blockArgument, cast(&op)); })) { consumesHandle(target, effects); } else { onlyReadsHandle(target, effects); } } if (any_of(getBody().front().without_terminator(), [&](Operation &op) { return doesModifyPayload(cast(&op)); })) { modifiesPayload(effects); } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { return doesReadPayload(cast(&op)); })) { onlyReadsPayload(effects); } producesHandle(getOperation()->getOpResults(), effects); } void transform::ForeachOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { Region *bodyRegion = &getBody(); if (point.isParent()) { regions.emplace_back(bodyRegion, bodyRegion->getArguments()); return; } // Branch back to the region or the parent. assert(point == getBody() && "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); regions.emplace_back(); } OperandRange transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. 
assert(point == getBody() && "unexpected region index"); return getOperation()->getOperands(); } transform::YieldOp transform::ForeachOp::getYieldOp() { return cast(getBody().front().getTerminator()); } LogicalResult transform::ForeachOp::verify() { for (auto [targetOpt, bodyArgOpt] : llvm::zip_longest(getTargets(), getBody().front().getArguments())) { if (!targetOpt || !bodyArgOpt) return emitOpError() << "expects the same number of targets as the body " "has block arguments"; if (targetOpt.value().getType() != bodyArgOpt.value().getType()) return emitOpError( "expects co-indexed targets and the body's " "block arguments to have the same op/value/param type"); } for (auto [resultOpt, yieldOperandOpt] : llvm::zip_longest(getResults(), getYieldOp().getOperands())) { if (!resultOpt || !yieldOperandOpt) return emitOpError() << "expects the same number of results as the " "yield terminator has operands"; if (resultOpt.value().getType() != yieldOperandOpt.value().getType()) return emitOpError("expects co-indexed results and yield " "operands to have the same op/value/param type"); } return success(); } //===----------------------------------------------------------------------===// // GetParentOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetParentOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector parents; DenseSet resultSet; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target; for (int64_t i = 0, e = getNthParent(); i < e; ++i) { parent = parent->getParentOp(); while (parent) { bool checkIsolatedFromAbove = !getIsolatedFromAbove() || parent->hasTrait(); bool checkOpName = !getOpName().has_value() || parent->getName().getStringRef() == *getOpName(); if (checkIsolatedFromAbove && checkOpName) break; parent = parent->getParentOp(); } if (!parent) { if (getAllowEmptyResults()) { 
results.set(llvm::cast(getResult()), parents); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find a parent op that matches all requirements"; diag.attachNote(target->getLoc()) << "target op"; return diag; } } if (getDeduplicate()) { if (resultSet.insert(parent).second) parents.push_back(parent); } else { parents.push_back(parent); } } results.set(llvm::cast(getResult()), parents); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GetConsumersOfResult //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); auto payloadOps = state.getPayloadOps(getTarget()); if (std::empty(payloadOps)) { results.set(cast(getResult()), {}); return DiagnosedSilenceableFailure::success(); } if (!llvm::hasSingleElement(payloadOps)) return emitDefiniteFailure() << "handle must be mapped to exactly one payload op"; Operation *target = *payloadOps.begin(); if (target->getNumResults() <= resultNumber) return emitDefiniteFailure() << "result number overflow"; results.set(llvm::cast(getResult()), llvm::to_vector(target->getResult(resultNumber).getUsers())); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GetDefiningOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { if (llvm::isa(v)) { DiagnosedSilenceableFailure diag = 
emitSilenceableError() << "cannot get defining op of block argument"; diag.attachNote(v.getLoc()) << "target value"; return diag; } definingOps.push_back(v.getDefiningOp()); } results.set(llvm::cast(getResult()), definingOps); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GetProducerOfOperand //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *producer = target->getNumOperands() <= operandNumber ? nullptr : target->getOperand(operandNumber).getDefiningOp(); if (!producer) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find a producer for operand number: " << operandNumber << " of " << *target; diag.attachNote(target->getLoc()) << "target op"; return diag; } producers.push_back(producer); } results.set(llvm::cast(getResult()), producers); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GetOperandOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector operands; for (Operation *target : state.getPayloadOps(getTarget())) { SmallVector operandPositions; DiagnosedSilenceableFailure diag = expandTargetSpecification( getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), target->getNumOperands(), operandPositions); if (diag.isSilenceableFailure()) { diag.attachNote(target->getLoc()) << "while considering positions of 
this payload operation"; return diag; } llvm::append_range(operands, llvm::map_range(operandPositions, [&](int64_t pos) { return target->getOperand(pos); })); } results.setValues(cast(getResult()), operands); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::GetOperandOp::verify() { return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), getIsInverted(), getIsAll()); } //===----------------------------------------------------------------------===// // GetResultOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetResultOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector opResults; for (Operation *target : state.getPayloadOps(getTarget())) { SmallVector resultPositions; DiagnosedSilenceableFailure diag = expandTargetSpecification( getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), target->getNumResults(), resultPositions); if (diag.isSilenceableFailure()) { diag.attachNote(target->getLoc()) << "while considering positions of this payload operation"; return diag; } llvm::append_range(opResults, llvm::map_range(resultPositions, [&](int64_t pos) { return target->getResult(pos); })); } results.setValues(cast(getResult()), opResults); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::GetResultOp::verify() { return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), getIsInverted(), getIsAll()); } //===----------------------------------------------------------------------===// // GetTypeOp //===----------------------------------------------------------------------===// void transform::GetTypeOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getValueMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); onlyReadsPayload(effects); } DiagnosedSilenceableFailure 
transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector params; for (Value value : state.getPayloadValues(getValue())) { Type type = value.getType(); if (getElemental()) { if (auto shaped = dyn_cast(type)) { type = shaped.getElementType(); } } params.push_back(TypeAttr::get(type)); } results.setParams(cast(getResult()), params); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // IncludeOp //===----------------------------------------------------------------------===// /// Applies the transform ops contained in `block`. Maps `results` to the same /// values as the operands of the block terminator. static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results) { // Apply the sequenced ops one by one. for (Operation &transform : block.without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isDefiniteFailure()) return result; if (result.isSilenceableFailure()) { if (mode == transform::FailurePropagationMode::Propagate) { // Propagate empty results in case of early exit. forwardEmptyOperands(&block, state, results); return result; } (void)result.silence(); } } // Forward the operation mapping for values yielded from the sequence to the // values produced by the sequence op. 
transform::detail::forwardTerminatorOperands(&block, state, results); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::IncludeOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); assert(callee && "unverified reference to unknown symbol"); if (callee.isExternal()) return emitDefiniteFailure() << "unresolved external named sequence"; // Map operands to block arguments. SmallVector> mappings; detail::prepareValueMappings(mappings, getOperands(), state); auto scope = state.make_region_scope(callee.getBody()); for (auto &&[arg, map] : llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { if (failed(state.mapBlockArgument(arg, map))) return DiagnosedSilenceableFailure::definiteFailure(); } DiagnosedSilenceableFailure result = applySequenceBlock( callee.getBody().front(), getFailurePropagationMode(), state, results); mappings.clear(); detail::prepareValueMappings( mappings, callee.getBody().front().getTerminator()->getOperands(), state); for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) results.setMappedValues(result, mapping); return result; } static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings); void transform::IncludeOp::getEffects( SmallVectorImpl &effects) { // Always mark as modifying the payload. // TODO: a mechanism to annotate effects on payload. Even when all handles are // only read, the payload may still be modified, so we currently stay on the // conservative side and always indicate modification. This may prevent some // code reordering. modifiesPayload(effects); // Results are always produced. producesHandle(getOperation()->getOpResults(), effects); // Adds default effects to operands and results. 
This will be added if // preconditions fail so the trait verifier doesn't complain about missing // effects and the real precondition failure is reported later on. auto defaultEffects = [&] { onlyReadsHandle(getOperation()->getOpOperands(), effects); }; // Bail if the callee is unknown. This may run as part of the verification // process before we verified the validity of the callee or of this op. auto target = getOperation()->getAttrOfType(getTargetAttrName()); if (!target) return defaultEffects(); auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); if (!callee) return defaultEffects(); DiagnosedSilenceableFailure earlyVerifierResult = verifyNamedSequenceOp(callee, /*emitWarnings=*/false); if (!earlyVerifierResult.succeeded()) { (void)earlyVerifierResult.silence(); return defaultEffects(); } for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) consumesHandle(getOperation()->getOpOperand(i), effects); else onlyReadsHandle(getOperation()->getOpOperand(i), effects); } } LogicalResult transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Access through indirection and do additional checking because this may be // running before the main op verifier. 
auto targetAttr = getOperation()->getAttrOfType("target"); if (!targetAttr) return emitOpError() << "expects a 'target' symbol reference attribute"; auto target = symbolTable.lookupNearestSymbolFrom( *this, targetAttr); if (!target) return emitOpError() << "does not reference a named transform sequence"; FunctionType fnType = target.getFunctionType(); if (fnType.getNumInputs() != getNumOperands()) return emitError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { if (getOperand(i).getType() != fnType.getInput(i)) { return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; } } if (fnType.getNumResults() != getNumResults()) return emitError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { Type resultType = getResult(i).getType(); Type funcType = fnType.getResult(i); if (!implementSameTransformInterface(resultType, funcType)) { return emitOpError() << "type of result #" << i << " must implement the same transform dialect " "interface as the corresponding callee result"; } } return verifyFunctionLikeConsumeAnnotations( cast(*target), /*emitWarnings=*/false, /*alsoVerifyInternal=*/true) .checkAndReport(); } //===----------------------------------------------------------------------===// // MatchOperationEmptyOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( ::std::optional<::mlir::Operation *> maybeCurrent, transform::TransformResults &results, transform::TransformState &state) { if (!maybeCurrent.has_value()) { DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); return DiagnosedSilenceableFailure::success(); } DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); return emitSilenceableError() << 
"operation is not empty"; } //===----------------------------------------------------------------------===// // MatchOperationNameOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { StringRef currentOpName = current->getName().getStringRef(); for (auto acceptedAttr : getOpNames().getAsRange()) { if (acceptedAttr.getValue() == currentOpName) return DiagnosedSilenceableFailure::success(); } return emitSilenceableError() << "wrong operation name"; } //===----------------------------------------------------------------------===// // MatchParamCmpIOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto signedAPIntAsString = [&](const APInt &value) { std::string str; llvm::raw_string_ostream os(str); value.print(os, /*isSigned=*/true); return str; }; ArrayRef params = state.getParams(getParam()); ArrayRef references = state.getParams(getReference()); if (params.size() != references.size()) { return emitSilenceableError() << "parameters have different payload lengths (" << params.size() << " vs " << references.size() << ")"; } for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { auto intAttr = llvm::dyn_cast(param); auto refAttr = llvm::dyn_cast(reference); if (!intAttr || !refAttr) { return emitDefiniteFailure() << "non-integer parameter value not expected"; } if (intAttr.getType() != refAttr.getType()) { return emitDefiniteFailure() << "mismatching integer attribute types in parameter #" << i; } APInt value = intAttr.getValue(); APInt refValue = refAttr.getValue(); // TODO: this copy will not be necessary in C++20. 
int64_t position = i; auto reportError = [&](StringRef direction) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected parameter to be " << direction << " " << signedAPIntAsString(refValue) << ", got " << signedAPIntAsString(value); diag.attachNote(getParam().getLoc()) << "value # " << position << " associated with the parameter defined here"; return diag; }; switch (getPredicate()) { case MatchCmpIPredicate::eq: if (value.eq(refValue)) break; return reportError("equal to"); case MatchCmpIPredicate::ne: if (value.ne(refValue)) break; return reportError("not equal to"); case MatchCmpIPredicate::lt: if (value.slt(refValue)) break; return reportError("less than"); case MatchCmpIPredicate::le: if (value.sle(refValue)) break; return reportError("less than or equal to"); case MatchCmpIPredicate::gt: if (value.sgt(refValue)) break; return reportError("greater than"); case MatchCmpIPredicate::ge: if (value.sge(refValue)) break; return reportError("greater than or equal to"); } } return DiagnosedSilenceableFailure::success(); } void transform::MatchParamCmpIOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getParamMutable(), effects); onlyReadsHandle(getReferenceMutable(), effects); } //===----------------------------------------------------------------------===// // ParamConstantOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setParams(cast(getParam()), {getValue()}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MergeHandlesOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults 
&results, transform::TransformState &state) { ValueRange handles = getHandles(); if (isa(handles.front().getType())) { SmallVector operations; for (Value operand : handles) llvm::append_range(operations, state.getPayloadOps(operand)); if (!getDeduplicate()) { results.set(llvm::cast(getResult()), operations); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(operations.begin(), operations.end()); results.set(llvm::cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } if (llvm::isa(handles.front().getType())) { SmallVector attrs; for (Value attribute : handles) llvm::append_range(attrs, state.getParams(attribute)); if (!getDeduplicate()) { results.setParams(cast(getResult()), attrs); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(attrs.begin(), attrs.end()); results.setParams(cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } assert( llvm::isa(handles.front().getType()) && "expected value handle type"); SmallVector payloadValues; for (Value value : handles) llvm::append_range(payloadValues, state.getPayloadValues(value)); if (!getDeduplicate()) { results.setValues(cast(getResult()), payloadValues); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(payloadValues.begin(), payloadValues.end()); results.setValues(cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { // Handles may be the same if deduplicating is enabled. return getDeduplicate(); } void transform::MergeHandlesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getHandlesMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. 
} OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { if (getDeduplicate() || getHandles().size() != 1) return {}; // If deduplication is not required and there is only one operand, it can be // used directly instead of merging. return getHandles().front(); } //===----------------------------------------------------------------------===// // NamedSequenceOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (isExternal()) return emitDefiniteFailure() << "unresolved external named sequence"; // Map the entry block argument to the list of operations. // Note: this is the same implementation as PossibleTopLevelTransformOp but // without attaching the interface / trait since that is tailored to a // dangling top-level op that does not get "called". auto scope = state.make_region_scope(getBody()); if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( state, this->getOperation(), getBody()))) return DiagnosedSilenceableFailure::definiteFailure(); return applySequenceBlock(getBody().front(), FailurePropagationMode::Propagate, state, results); } void transform::NamedSequenceOp::getEffects( SmallVectorImpl &effects) {} ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), [](Builder &builder, ArrayRef inputs, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(inputs, results); }, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { function_interface_impl::printFunctionOp( printer, cast(getOperation()), /*isVariadic=*/false, 
getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), getResAttrsAttrName()); } /// Verifies that a symbol function-like transform dialect operation has the /// signature and the terminator that have conforming types, i.e., types /// implementing the same transform dialect type interface. If `allowExternal` /// is set, allow external symbols (declarations) and don't check the terminator /// as it may not exist. static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { if (auto parent = op->getParentOfType()) { DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) << "cannot be defined inside another transform op"; diag.attachNote(parent.getLoc()) << "ancestor transform op"; return diag; } if (op.isExternal() || op.getFunctionBody().empty()) { if (allowExternal) return DiagnosedSilenceableFailure::success(); return emitSilenceableFailure(op) << "cannot be external"; } if (op.getFunctionBody().front().empty()) return emitSilenceableFailure(op) << "expected a non-empty body block"; Operation *terminator = &op.getFunctionBody().front().back(); if (!isa(terminator)) { DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) << "expected '" << transform::YieldOp::getOperationName() << "' as terminator"; diag.attachNote(terminator->getLoc()) << "terminator"; return diag; } if (terminator->getNumOperands() != op.getResultTypes().size()) { return emitSilenceableFailure(terminator) << "expected terminator to have as many operands as the parent op " "has results"; } for (auto [i, operandType, resultType] : llvm::zip_equal( llvm::seq(0, terminator->getNumOperands()), terminator->getOperands().getType(), op.getResultTypes())) { if (operandType == resultType) continue; return emitSilenceableFailure(terminator) << "the type of the terminator operand #" << i << " must match the type of the corresponding parent op result (" << operandType << " vs " << resultType << ")"; } return 
DiagnosedSilenceableFailure::success(); } /// Verification of a NamedSequenceOp. This does not report the error /// immediately, so it can be used to check for op's well-formedness before the /// verifier runs, e.g., during trait verification. static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { if (Operation *parent = op->getParentWithTrait()) { if (!parent->getAttr( transform::TransformDialect::kWithNamedSequenceAttrName)) { DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) << "expects the parent symbol table to have the '" << transform::TransformDialect::kWithNamedSequenceAttrName << "' attribute"; diag.attachNote(parent->getLoc()) << "symbol table operation"; return diag; } } if (auto parent = op->getParentOfType()) { DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) << "cannot be defined inside another transform op"; diag.attachNote(parent.getLoc()) << "ancestor transform op"; return diag; } if (op.isExternal() || op.getBody().empty()) return verifyFunctionLikeConsumeAnnotations(cast(*op), emitWarnings); if (op.getBody().front().empty()) return emitSilenceableFailure(op) << "expected a non-empty body block"; Operation *terminator = &op.getBody().front().back(); if (!isa(terminator)) { DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) << "expected '" << transform::YieldOp::getOperationName() << "' as terminator"; diag.attachNote(terminator->getLoc()) << "terminator"; return diag; } if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { return emitSilenceableFailure(terminator) << "expected terminator to have as many operands as the parent op " "has results"; } for (auto [i, operandType, resultType] : llvm::zip_equal(llvm::seq(0, terminator->getNumOperands()), terminator->getOperands().getType(), op.getFunctionType().getResults())) { if (operandType == resultType) continue; return emitSilenceableFailure(terminator) << "the type of the terminator 
operand #" << i << " must match the type of the corresponding parent op result (" << operandType << " vs " << resultType << ")"; } auto funcOp = cast(*op); DiagnosedSilenceableFailure diag = verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); if (!diag.succeeded()) return diag; return verifyYieldingSingleBlockOp(funcOp, /*allowExternal=*/true); } LogicalResult transform::NamedSequenceOp::verify() { // Actual verification happens in a separate function for reusability. return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); } template static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder) { SmallVector types; types.reserve(1 + extraBindingTypes.size()); types.push_back(bbArgType); llvm::append_range(types, extraBindingTypes); OpBuilder::InsertionGuard guard(builder); Region *region = state.regions.back().get(); Block *bodyBlock = builder.createBlock(region, region->begin(), types, SmallVector(types.size(), state.location)); // Populate body. 
builder.setInsertionPointToStart(bodyBlock); if constexpr (llvm::function_traits::num_args == 3) { bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); } else { bodyBuilder(builder, state.location, bodyBlock->getArgument(0), bodyBlock->getArguments().drop_front()); } } void transform::NamedSequenceOp::build(OpBuilder &builder, OperationState &state, StringRef symName, Type rootType, TypeRange resultTypes, SequenceBodyBuilderFn bodyBuilder, ArrayRef attrs, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(symName)); state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes))); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); buildSequenceBody(builder, state, rootType, /*extraBindingTypes=*/TypeRange(), bodyBuilder); } //===----------------------------------------------------------------------===// // NumAssociationsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { size_t numAssociations = llvm::TypeSwitch(getHandle().getType()) .Case([&](TransformHandleTypeInterface opHandle) { return llvm::range_size(state.getPayloadOps(getHandle())); }) .Case([&](TransformValueHandleTypeInterface valueHandle) { return llvm::range_size(state.getPayloadValues(getHandle())); }) .Case([&](TransformParamTypeInterface param) { return llvm::range_size(state.getParams(getHandle())); }) .Default([](Type) { llvm_unreachable("unknown kind of transform dialect type"); return 0; }); results.setParams(cast(getNum()), rewriter.getI64IntegerAttr(numAssociations)); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::NumAssociationsOp::verify() { // Verify that the result type accepts an i64 attribute as payload. 
auto resultType = cast(getNum().getType()); return resultType .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) .checkAndReport(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::SelectOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector result; auto payloadOps = state.getPayloadOps(getTarget()); for (Operation *op : payloadOps) { if (op->getName().getStringRef() == getOpName()) result.push_back(op); } results.set(cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // SplitHandleOp //===----------------------------------------------------------------------===// void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, Value target, int64_t numResultHandles) { result.addOperands(target); result.addTypes(SmallVector(numResultHandles, target.getType())); } DiagnosedSilenceableFailure transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t numPayloads = llvm::TypeSwitch(getHandle().getType()) .Case([&](auto x) { return llvm::range_size(state.getPayloadOps(getHandle())); }) .Case([&](auto x) { return llvm::range_size(state.getPayloadValues(getHandle())); }) .Case([&](auto x) { return llvm::range_size(state.getParams(getHandle())); }) .Default([](auto x) { llvm_unreachable("unknown transform dialect type interface"); return -1; }); auto produceNumOpsError = [&]() { return emitSilenceableError() << getHandle() << " expected to contain " << this->getNumResults() << " payloads but it contains " << numPayloads << " payloads"; }; // Fail if there are more payload ops than results and no 
overflow result was // specified. if (numPayloads > getNumResults() && !getOverflowResult().has_value()) return produceNumOpsError(); // Fail if there are more results than payload ops. Unless: // - "fail_on_payload_too_small" is set to "false", or // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() && (numPayloads != 0 || !getPassThroughEmptyHandle())) return produceNumOpsError(); // Distribute payloads. SmallVector> resultHandles(getNumResults(), {}); if (getOverflowResult()) resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults()); auto container = [&]() { if (isa(getHandle().getType())) { return llvm::map_to_vector( state.getPayloadOps(getHandle()), [](Operation *op) -> MappedValue { return op; }); } if (isa(getHandle().getType())) { return llvm::map_to_vector(state.getPayloadValues(getHandle()), [](Value v) -> MappedValue { return v; }); } assert(isa(getHandle().getType()) && "unsupported kind of transform dialect type"); return llvm::map_to_vector(state.getParams(getHandle()), [](Attribute a) -> MappedValue { return a; }); }(); for (auto &&en : llvm::enumerate(container)) { int64_t resultNum = en.index(); if (resultNum >= getNumResults()) resultNum = *getOverflowResult(); resultHandles[resultNum].push_back(en.value()); } // Set transform op results. for (auto &&it : llvm::enumerate(resultHandles)) results.setMappedValues(llvm::cast(getResult(it.index())), it.value()); return DiagnosedSilenceableFailure::success(); } void transform::SplitHandleOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getHandleMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. 
} LogicalResult transform::SplitHandleOp::verify() { if (getOverflowResult().has_value() && !(*getOverflowResult() < getNumResults())) return emitOpError("overflow_result is not a valid result index"); for (Type resultType : getResultTypes()) { if (implementSameTransformInterface(getHandle().getType(), resultType)) continue; return emitOpError("expects result types to implement the same transform " "interface as the operand type"); } return success(); } //===----------------------------------------------------------------------===// // ReplicateOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); if (isa(handle.getType())) { SmallVector current = llvm::to_vector(state.getPayloadOps(handle)); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(payload, current); results.set(llvm::cast(getReplicated()[en.index()]), payload); } else { assert(llvm::isa(handle.getType()) && "expected param type"); ArrayRef current = state.getParams(handle); SmallVector params; params.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(params, current); results.setParams(llvm::cast(getReplicated()[en.index()]), params); } } return DiagnosedSilenceableFailure::success(); } void transform::ReplicateOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getPatternMutable(), effects); onlyReadsHandle(getHandlesMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); } //===----------------------------------------------------------------------===// // SequenceOp 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::SequenceOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); if (failed(mapBlockArguments(state))) return DiagnosedSilenceableFailure::definiteFailure(); return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, results); } static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, SmallVectorImpl &extraBindings, SmallVectorImpl &extraBindingTypes) { OpAsmParser::UnresolvedOperand rootOperand; OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand); if (!hasRoot.has_value()) { root = std::nullopt; return success(); } if (failed(hasRoot.value())) return failure(); root = rootOperand; if (succeeded(parser.parseOptionalComma())) { if (failed(parser.parseOperandList(extraBindings))) return failure(); } if (failed(parser.parseColon())) return failure(); // The paren is truly optional. (void)parser.parseOptionalLParen(); if (failed(parser.parseType(rootType))) { return failure(); } if (!extraBindings.empty()) { if (parser.parseComma() || parser.parseTypeList(extraBindingTypes)) return failure(); } if (extraBindingTypes.size() != extraBindings.size()) { return parser.emitError(parser.getNameLoc(), "expected types to be provided for all operands"); } // The paren is truly optional. 
(void)parser.parseOptionalRParen(); return success(); } static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes) { if (!root) return; printer << root; bool hasExtras = !extraBindings.empty(); if (hasExtras) { printer << ", "; printer.printOperands(extraBindings); } printer << " : "; if (hasExtras) printer << "("; printer << rootType; if (hasExtras) { printer << ", "; llvm::interleaveComma(extraBindingTypes, printer.getStream()); printer << ")"; } } /// Returns `true` if the given op operand may be consuming the handle value in /// the Transform IR. That is, if it may have a Free effect on it. static bool isValueUsePotentialConsumer(OpOperand &use) { // Conservatively assume the effect being present in absence of the interface. auto iface = dyn_cast(use.getOwner()); if (!iface) return true; return isHandleConsumed(use.get(), iface); } LogicalResult checkDoubleConsume(Value value, function_ref reportError) { OpOperand *potentialConsumer = nullptr; for (OpOperand &use : value.getUses()) { if (!isValueUsePotentialConsumer(use)) continue; if (!potentialConsumer) { potentialConsumer = &use; continue; } InFlightDiagnostic diag = reportError() << " has more than one potential consumer"; diag.attachNote(potentialConsumer->getOwner()->getLoc()) << "used here as operand #" << potentialConsumer->getOperandNumber(); diag.attachNote(use.getOwner()->getLoc()) << "used here as operand #" << use.getOperandNumber(); return diag; } return success(); } LogicalResult transform::SequenceOp::verify() { assert(getBodyBlock()->getNumArguments() >= 1 && "the number of arguments must have been verified to be more than 1 by " "PossibleTopLevelTransformOpTrait"); if (!getRoot() && !getExtraBindings().empty()) { return emitOpError() << "does not expect extra operands when used as top-level"; } // Check if a block argument has more than one consuming use. 
for (BlockArgument arg : getBodyBlock()->getArguments()) { if (failed(checkDoubleConsume(arg, [this, arg]() { return (emitOpError() << "block argument #" << arg.getArgNumber()); }))) { return failure(); } } // Check properties of the nested operations they cannot check themselves. for (Operation &child : *getBodyBlock()) { if (!isa(child) && &child != &getBodyBlock()->back()) { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement TransformOpInterface"; diag.attachNote(child.getLoc()) << "op without interface"; return diag; } for (OpResult result : child.getResults()) { auto report = [&]() { return (child.emitError() << "result #" << result.getResultNumber()); }; if (failed(checkDoubleConsume(result, report))) return failure(); } } if (!getBodyBlock()->mightHaveTerminator()) return emitOpError() << "expects to have a terminator in the body"; if (getBodyBlock()->getTerminator()->getOperandTypes() != getOperation()->getResultTypes()) { InFlightDiagnostic diag = emitOpError() << "expects the types of the terminator operands " "to match the types of the result"; diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; return diag; } return success(); } void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { getPotentialTopLevelEffects(effects); } OperandRange transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::SequenceOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (point.isParent()) { Region *bodyRegion = &getBody(); regions.emplace_back(bodyRegion, getNumOperands() != 0 ? 
bodyRegion->getArguments() : Block::BlockArgListType()); return; } assert(point == getBody() && "unexpected region index"); regions.emplace_back(getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; bounds.emplace_back(1, 1); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Value root, SequenceBodyBuilderFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, root, /*extra_bindings=*/ValueRange()); Type bbArgType = root.getType(); buildSequenceBody(builder, state, bbArgType, /*extraBindingTypes=*/TypeRange(), bodyBuilder); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Value root, ValueRange extraBindings, SequenceBodyBuilderArgsFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, root, extraBindings); buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), bodyBuilder); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Type bbArgType, SequenceBodyBuilderFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), /*extra_bindings=*/ValueRange()); buildSequenceBody(builder, state, bbArgType, /*extraBindingTypes=*/TypeRange(), bodyBuilder); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Type bbArgType, TypeRange extraBindingTypes, SequenceBodyBuilderArgsFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), /*extra_bindings=*/ValueRange()); buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); } 
//===----------------------------------------------------------------------===// // PrintOp //===----------------------------------------------------------------------===// void transform::PrintOp::build(OpBuilder &builder, OperationState &result, StringRef name) { if (!name.empty()) result.getOrAddProperties().name = builder.getStringAttr(name); } void transform::PrintOp::build(OpBuilder &builder, OperationState &result, Value target, StringRef name) { result.addOperands({target}); build(builder, result, name); } DiagnosedSilenceableFailure transform::PrintOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) llvm::outs() << *getName() << " "; OpPrintingFlags printFlags; if (getAssumeVerified().value_or(false)) printFlags.assumeVerified(); if (getUseLocalScope().value_or(false)) printFlags.useLocalScope(); if (getSkipRegions().value_or(false)) printFlags.skipRegions(); if (!getTarget()) { llvm::outs() << "top-level ]]]\n"; state.getTopLevel()->print(llvm::outs(), printFlags); llvm::outs() << "\n"; llvm::outs().flush(); return DiagnosedSilenceableFailure::success(); } llvm::outs() << "]]]\n"; for (Operation *target : state.getPayloadOps(getTarget())) { target->print(llvm::outs(), printFlags); llvm::outs() << "\n"; } llvm::outs().flush(); return DiagnosedSilenceableFailure::success(); } void transform::PrintOp::getEffects( SmallVectorImpl &effects) { // We don't really care about mutability here, but `getTarget` now // unconditionally casts to a specific type before verification could run // here. if (!getTargetMutable().empty()) onlyReadsHandle(getTargetMutable()[0], effects); onlyReadsPayload(effects); // There is no resource for stderr file descriptor, so just declare print // writes into the default resource. 
effects.emplace_back(MemoryEffects::Write::get()); } //===----------------------------------------------------------------------===// // VerifyOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (failed(::mlir::verify(target))) { DiagnosedDefiniteFailure diag = emitDefiniteFailure() << "failed to verify payload op"; diag.attachNote(target->getLoc()) << "payload op"; return diag; } return DiagnosedSilenceableFailure::success(); } void transform::VerifyOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// void transform::YieldOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getOperandsMutable(), effects); }