//===- TestTransformDialectExtension.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines an extension of the MLIR Transform dialect for testing // purposes. // //===----------------------------------------------------------------------===// #include "TestTransformDialectExtension.h" #include "TestTransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; namespace { /// Simple transform op defined outside of the dialect. Just emits a remark when /// applied. This op is defined in C++ to test that C++ definitions also work /// for op injection into the Transform dialect. class TestTransformOp : public Op { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) using Op::Op; static ArrayRef getAttributeNames() { return {}; } static constexpr llvm::StringLiteral getOperationName() { return llvm::StringLiteral("transform.test_transform_op"); } DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) remark << " " << message; return DiagnosedSilenceableFailure::success(); } Attribute getMessage() { return getOperation()->getDiscardableAttr("message"); } static ParseResult parse(OpAsmParser &parser, OperationState &state) { StringAttr message; OptionalParseResult result = parser.parseOptionalAttribute(message); if (!result.has_value()) return success(); if (result.value().succeeded()) state.addAttribute("message", message); return result.value(); } void print(OpAsmPrinter &printer) { if (getMessage()) printer << " " << getMessage(); } // No side effects. void getEffects(SmallVectorImpl &effects) {} }; /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait /// in cases where it is attached to ops that do not comply with the trait /// requirements. This op cannot be defined in ODS because ODS generates strict /// verifiers that overalp with those in the trait and run earlier. class TestTransformUnrestrictedOpNoInterface : public Op { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformUnrestrictedOpNoInterface) using Op::Op; static ArrayRef getAttributeNames() { return {}; } static constexpr llvm::StringLiteral getOperationName() { return llvm::StringLiteral( "transform.test_transform_unrestricted_op_no_interface"); } DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } // No side effects. void getEffects(SmallVectorImpl &effects) {} }; } // namespace DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(cast(getResult()), {getOperation()->getOperand(0).getDefiningOp()}); } else { results.set(cast(getResult()), {getOperation()}); } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( SmallVectorImpl &effects) { if (getOperand()) transform::onlyReadsHandle(getOperandMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), {getIn()}); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToResult::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->getNumResults() <= getNumber()) return emitSilenceableError() << "payload has no result #" << getNumber(); results.push_back(target->getResult(getNumber())); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceValueHandleToResult::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->getBlock()) return emitSilenceableError() << "payload has no parent block"; if (target->getBlock()->getNumArguments() <= getNumber()) return emitSilenceableError() << "parent of the payload has no argument #" << getNumber(); results.push_back(target->getBlock()->getArgument(getNumber())); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); transform::onlyReadsPayload(effects); } bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { return getAllowRepeatedHandles(); } DiagnosedSilenceableFailure mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } void mlir::test::TestConsumeOperand::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getOperation()->getOpOperands(), effects); if (getSecondOperand()) transform::consumesHandle(getSecondOperandMutable(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); assert(llvm::hasSingleElement(payload) && "expected a single target op"); if ((*payload.begin())->getName().getStringRef() != getOpKind()) { return emitSilenceableError() << "op expected the operand to be associated a payload op of kind " << getOpKind() << " got " << (*payload.begin())->getName().getStringRef(); } emitRemark() << "succeeded"; return DiagnosedSilenceableFailure::success(); } void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getOperation()->getOpOperands(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( Operation *op, transform::TransformResults &results, transform::TransformState &state) { if (op->getName().getStringRef() != getOpKind()) { return emitSilenceableError() << "op expected the operand to be associated with a payload op of " "kind " << getOpKind() << " got " << op->getName().getStringRef(); } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) { emitRemark() << "extension absent"; return DiagnosedSilenceableFailure::success(); } InFlightDiagnostic diag = emitRemark() << "extension present, " << extension->getMessage(); for (Operation *payload : state.getPayloadOps(getOperand())) { diag.attachNote(payload->getLoc()) << "associated payload op"; #ifndef NDEBUG SmallVector handles; assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); assert(llvm::is_contained(handles, getOperand()) && "inconsistent mapping between transform IR handles and payload IR " "operations"); #endif // NDEBUG } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) return emitDefiniteFailure("TestTransformStateExtension missing"); if (failed(extension->updateMapping( *state.getPayloadOps(getOperand()).begin(), getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) results.set(cast(getResult(0)), {getOperation()}); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } void mlir::test::TestTransformOpWithRegions::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } void mlir::test::TestBranchingTransformOpTerminator::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) rewriter.eraseOp(op); if (getFailAfterErase()) return emitSilenceableError() << "silenceable error"; return DiagnosedSilenceableFailure::success(); } void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); return emitDefaultSilenceableFailure(target); } DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getCopy()), state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestCopyPayloadOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getHandleMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( Location loc, ArrayRef payload) const { if (payload.empty()) return DiagnosedSilenceableFailure::success(); for (Operation *op : payload) { if (op->getName().getDialectNamespace() != "test") { return emitSilenceableError(loc) << "expected the payload operation to " "belong to the 'test' dialect"; } } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( Location loc, ArrayRef payload) const { for (Attribute attr : payload) { auto integerAttr = llvm::dyn_cast(attr); if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; return emitSilenceableError(loc) << "expected the parameter to be a i32 integer attribute"; } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); } DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { op->walk([&](Operation *nested) { SmallVector handles; (void)state.getHandlesForPayloadOp(nested, handles); count += handles.size(); }); } emitRemark() << count << " handles nested under"; return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, /*Value=*/0); if (Value param = getParam()) { values = llvm::to_vector( llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { return llvm::cast(attr).getValue().getLimitedValue( UINT32_MAX); })); } Builder builder(getContext()); SmallVector result = llvm::to_vector( llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { return builder.getI32IntegerAttr(value + getAddendum()); })); results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( llvm::map_range(state.getPayloadOps(getHandle()), [&builder](Operation *payload) -> Attribute { int32_t count = 0; payload->walk([&count](Operation *op) { if (op->getName().getDialectNamespace() == "test") ++count; }); return builder.getI32IntegerAttr(count); })); results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setParams(llvm::cast(getResult()), getAttr()); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInMutable(), effects); transform::producesHandle(getOperation()->getOpResults(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { results.push_back(builder.getI64IntegerAttr(0)); } else if (getFirstResultIsNull()) { results.push_back(nullptr); } else { results.push_back(*state.getPayloadOps(getIn()).begin()); } if (getSecondResultIsHandle()) { results.push_back(*state.getPayloadOps(getIn()).begin()); } else { results.push_back(builder.getI64IntegerAttr(42)); } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceNullPayloadOp::getEffects( SmallVectorImpl &effects) { transform::producesHandle(getOperation()->getOpResults(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(llvm::cast(getOut()), null); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(cast(getOut()), {}); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceNullParamOp::getEffects( SmallVectorImpl &effects) { transform::producesHandle(getOperation()->getOpResults(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceNullValueOp::getEffects( SmallVectorImpl &effects) { transform::producesHandle(getOperation()->getOpResults(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), {Value()}); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestRequiredMemoryEffectsOp::getEffects( SmallVectorImpl &effects) { if (getHasOperandEffect()) transform::consumesHandle(getInMutable(), effects); if (getHasResultEffect()) { transform::producesHandle(getOperation()->getOpResults(), effects); } else { effects.emplace_back(MemoryEffects::Read::get(), llvm::cast(getOut()), transform::TransformMappingResource::get()); } if (getModifiesPayload()) transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestTrackedRewriteOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInMutable(), effects); transform::modifiesPayload(effects); } void mlir::test::TestDummyPayloadOp::getEffects( SmallVectorImpl &effects) { transform::producesHandle(getOperation()->getOpResults(), effects); } LogicalResult mlir::test::TestDummyPayloadOp::verify() { if (getFailToVerify()) return emitOpError() << "fail_to_verify is set"; return success(); } DiagnosedSilenceableFailure mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t numIterations = 0; // `getPayloadOps` returns an iterator that skips ops that are erased in the // loop body. Replacement ops are not enumerated. for (Operation *op : state.getPayloadOps(getIn())) { ++numIterations; (void)op; // Erase all payload ops. The outer loop should have only one iteration. for (Operation *op : state.getPayloadOps(getIn())) { rewriter.setInsertionPoint(op); if (op->hasAttr("erase_me")) { rewriter.eraseOp(op); continue; } if (!op->hasAttr("replace_me")) { continue; } SmallVector attributes; attributes.emplace_back(rewriter.getStringAttr("new_op"), rewriter.getUnitAttr()); OperationState opState(op->getLoc(), op->getName().getIdentifier(), /*operands=*/ValueRange(), /*types=*/op->getResultTypes(), attributes); Operation *newOp = rewriter.create(opState); rewriter.replaceOp(op, newOp->getResults()); } } emitRemark() << numIterations << " iterations"; return DiagnosedSilenceableFailure::success(); } namespace { // Test pattern to replace an operation with a new op. class ReplaceWithNewOp : public RewritePattern { public: ReplaceWithNewOp(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto newName = op->getAttrOfType("replace_with_new_op"); if (!newName) return failure(); Operation *newOp = rewriter.create( op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), op->getOperands(), op->getResultTypes()); rewriter.replaceOp(op, newOp->getResults()); return success(); } }; // Test pattern to erase an operation. class EraseOp : public RewritePattern { public: EraseOp(MLIRContext *context) : RewritePattern("test.erase_op", /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; } // namespace void mlir::test::ApplyTestPatternsOp::populatePatterns( RewritePatternSet &patterns) { patterns.insert(patterns.getContext()); } void mlir::test::TestReEnterRegionOp::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getOperation()->getOpOperands(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector> mappings; for (BlockArgument arg : getBody().front().getArguments()) { mappings.emplace_back(llvm::to_vector(llvm::map_range( state.getPayloadOps(getOperand(arg.getArgNumber())), [](Operation *op) -> transform::MappedValue { return op; }))); } for (int i = 0; i < 4; ++i) { auto scope = state.make_region_scope(getBody()); for (BlockArgument arg : getBody().front().getArguments()) { if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) return DiagnosedSilenceableFailure::definiteFailure(); } for (Operation &op : getBody().front().without_terminator()) { DiagnosedSilenceableFailure diag = state.applyTransform(cast(op)); if (!diag.succeeded()) return diag; } } return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestReEnterRegionOp::verify() { if (getNumOperands() != getBody().front().getNumArguments()) { return emitOpError() << "expects as many operands as block arguments"; } return success(); } DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto originalOps = state.getPayloadOps(getOriginal()); auto replacementOps = state.getPayloadOps(getReplacement()); if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) return emitSilenceableError() << "expected same number of original and " "replacement payload operations"; for (const auto &[original, replacement] : llvm::zip(originalOps, replacementOps)) { if (failed( rewriter.notifyPayloadOperationReplaced(original, replacement))) { auto diag = emitSilenceableError() << "unable to replace payload op in transform mapping"; diag.attachNote(original->getLoc()) << "original payload op"; diag.attachNote(replacement->getLoc()) << "replacement payload op"; return diag; } } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getOriginalMutable(), effects); transform::onlyReadsHandle(getReplacementMutable(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Provide some IR that does not verify. rewriter.setInsertionPointToStart(&target->getRegion(0).front()); rewriter.create(target->getLoc(), TypeRange(), ValueRange(), /*failToVerify=*/true); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceInvalidIR::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { std::string opName = this->getOperationName().str() + "_" + getTypeAttr().str(); TransformStateInitializerExtension *initExt = state.getExtension(); if (!initExt) { emitRemark() << "\nSpecified extension not found, adding a new one!\n"; SmallVector opCollection = {opName}; state.addExtension(1, opCollection); } else { initExt->setNumOp(initExt->getNumOp() + 1); initExt->pushRegisteredOps(opName); InFlightDiagnostic diag = emitRemark() << "Number of currently registered op: " << initExt->getNumOp() << "\n" << initExt->printMessage() << "\n"; } return DiagnosedSilenceableFailure::success(); } namespace { /// Test conversion pattern that replaces ops with the "replace_with_new_op" /// attribute with "test.new_op". class ReplaceWithNewOpConversion : public ConversionPattern { public: ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!op->hasAttr("replace_with_new_op")) return failure(); SmallVector newResultTypes; if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), newResultTypes))) return failure(); Operation *newOp = rewriter.create( op->getLoc(), OperationName("test.new_op", op->getContext()).getIdentifier(), operands, newResultTypes); rewriter.replaceOp(op, newOp->getResults()); return success(); } }; } // namespace void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.insert(typeConverter, patterns.getContext()); } namespace { /// Test type converter that converts tensor types to memref types. class TestTypeConverter : public TypeConverter { public: TestTypeConverter() { addConversion([](Type t) { return t; }); addConversion([](RankedTensorType type) -> Type { return MemRefType::get(type.getShape(), type.getElementType()); }); auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) return Value(); return builder.create(loc, resultType, inputs) .getResult(0); }; addSourceMaterialization(unrealizedCastConverter); addTargetMaterialization(unrealizedCastConverter); } }; } // namespace std::unique_ptr<::mlir::TypeConverter> mlir::test::TestTypeConverterOp::getTypeConverter() { return std::make_unique(); } namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL /// types for operands and results. class TestTransformDialectExtension : public transform::TransformDialectExtension< TestTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) using Base::Base; void init() { declareDependentDialect(); registerTransformOps(); registerTypes< #define GET_TYPEDEF_LIST #include "TestTransformDialectExtensionTypes.cpp.inc" >(); auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, ArrayRef pdlValues) { for (const PDLValue &pdlValue : pdlValues) { if (Operation *op = pdlValue.dyn_cast()) { op->emitWarning() << "from PDL constraint"; } } return success(); }; addDialectDataInitializer( [&](transform::PDLMatchHooks &hooks) { llvm::StringMap constraints; constraints.try_emplace("verbose_constraint", verboseConstraint); hooks.mergeInPDLMatchHooks(std::move(constraints)); }); } }; } // namespace // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. LLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); LLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "TestTransformDialectExtensionTypes.cpp.inc" #define GET_OP_CLASSES #include "TestTransformDialectExtension.cpp.inc" void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { registry.addExtensions(); }