1d064c480SAlex Zinenko //===- TestTransformDialectExtension.cpp ----------------------------------===// 2d064c480SAlex Zinenko // 3d064c480SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d064c480SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5d064c480SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d064c480SAlex Zinenko // 7d064c480SAlex Zinenko //===----------------------------------------------------------------------===// 8d064c480SAlex Zinenko // 9d064c480SAlex Zinenko // This file defines an extension of the MLIR Transform dialect for testing 10d064c480SAlex Zinenko // purposes. 11d064c480SAlex Zinenko // 12d064c480SAlex Zinenko //===----------------------------------------------------------------------===// 13d064c480SAlex Zinenko 14d064c480SAlex Zinenko #include "TestTransformDialectExtension.h" 156c57b0deSAlex Zinenko #include "TestTransformStateExtension.h" 16d064c480SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDL.h" 17d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 180e37ef08SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformOps.h" 195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 2094d608d4SAlex Zinenko #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" 21d064c480SAlex Zinenko #include "mlir/IR/OpImplementation.h" 2294d608d4SAlex Zinenko #include "mlir/IR/PatternMatch.h" 23ed02fa81SAlex Zinenko #include "llvm/ADT/STLExtras.h" 24bba85ebdSAlex Zinenko #include "llvm/ADT/TypeSwitch.h" 25bba85ebdSAlex Zinenko #include "llvm/Support/Compiler.h" 26ed02fa81SAlex Zinenko #include "llvm/Support/raw_ostream.h" 27d064c480SAlex Zinenko 28d064c480SAlex Zinenko using namespace mlir; 29d064c480SAlex Zinenko 30d064c480SAlex Zinenko namespace { 31d064c480SAlex Zinenko /// Simple transform op defined outside of the dialect. Just emits a remark when 3230f22429SAlex Zinenko /// applied. This op is defined in C++ to test that C++ definitions also work 3330f22429SAlex Zinenko /// for op injection into the Transform dialect. 34d064c480SAlex Zinenko class TestTransformOp 3540a8bd63SAlex Zinenko : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 3640a8bd63SAlex Zinenko MemoryEffectOpInterface::Trait> { 37d064c480SAlex Zinenko public: 38d064c480SAlex Zinenko MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 39d064c480SAlex Zinenko 40d064c480SAlex Zinenko using Op::Op; 41d064c480SAlex Zinenko 42d064c480SAlex Zinenko static ArrayRef<StringRef> getAttributeNames() { return {}; } 43d064c480SAlex Zinenko 44d064c480SAlex Zinenko static constexpr llvm::StringLiteral getOperationName() { 45d064c480SAlex Zinenko return llvm::StringLiteral("transform.test_transform_op"); 46d064c480SAlex Zinenko } 47d064c480SAlex Zinenko 48c63d2b2cSMatthias Springer DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, 49c63d2b2cSMatthias Springer transform::TransformResults &results, 50d064c480SAlex Zinenko transform::TransformState &state) { 510eb403adSAlex Zinenko InFlightDiagnostic remark = emitRemark() << "applying transformation"; 520eb403adSAlex Zinenko if (Attribute message = getMessage()) 530eb403adSAlex Zinenko remark << " " << message; 540eb403adSAlex Zinenko 551d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 56d064c480SAlex Zinenko } 57d064c480SAlex Zinenko 58830b9b07SMehdi Amini Attribute getMessage() { 59830b9b07SMehdi Amini return getOperation()->getDiscardableAttr("message"); 60830b9b07SMehdi Amini } 610eb403adSAlex Zinenko 62d064c480SAlex Zinenko static ParseResult parse(OpAsmParser &parser, OperationState &state) { 630eb403adSAlex Zinenko StringAttr message; 640eb403adSAlex Zinenko OptionalParseResult result = parser.parseOptionalAttribute(message); 659750648cSKazu Hirata if (!result.has_value()) 66d064c480SAlex Zinenko return success(); 670eb403adSAlex Zinenko 68c8e6ebd7SKazu Hirata if (result.value().succeeded()) 690eb403adSAlex Zinenko state.addAttribute("message", message); 70c8e6ebd7SKazu Hirata return result.value(); 71d064c480SAlex Zinenko } 72d064c480SAlex Zinenko 730eb403adSAlex Zinenko void print(OpAsmPrinter &printer) { 740eb403adSAlex Zinenko if (getMessage()) 750eb403adSAlex Zinenko printer << " " << getMessage(); 760eb403adSAlex Zinenko } 7740a8bd63SAlex Zinenko 7840a8bd63SAlex Zinenko // No side effects. 7940a8bd63SAlex Zinenko void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 80d064c480SAlex Zinenko }; 8130f22429SAlex Zinenko 8230f22429SAlex Zinenko /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 8330f22429SAlex Zinenko /// in cases where it is attached to ops that do not comply with the trait 8430f22429SAlex Zinenko /// requirements. This op cannot be defined in ODS because ODS generates strict 8530f22429SAlex Zinenko /// verifiers that overalp with those in the trait and run earlier. 8630f22429SAlex Zinenko class TestTransformUnrestrictedOpNoInterface 8730f22429SAlex Zinenko : public Op<TestTransformUnrestrictedOpNoInterface, 8830f22429SAlex Zinenko transform::PossibleTopLevelTransformOpTrait, 8940a8bd63SAlex Zinenko transform::TransformOpInterface::Trait, 9040a8bd63SAlex Zinenko MemoryEffectOpInterface::Trait> { 9130f22429SAlex Zinenko public: 9230f22429SAlex Zinenko MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 9330f22429SAlex Zinenko TestTransformUnrestrictedOpNoInterface) 9430f22429SAlex Zinenko 9530f22429SAlex Zinenko using Op::Op; 9630f22429SAlex Zinenko 9730f22429SAlex Zinenko static ArrayRef<StringRef> getAttributeNames() { return {}; } 9830f22429SAlex Zinenko 9930f22429SAlex Zinenko static constexpr llvm::StringLiteral getOperationName() { 10030f22429SAlex Zinenko return llvm::StringLiteral( 10130f22429SAlex Zinenko "transform.test_transform_unrestricted_op_no_interface"); 10230f22429SAlex Zinenko } 10330f22429SAlex Zinenko 104c63d2b2cSMatthias Springer DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, 105c63d2b2cSMatthias Springer transform::TransformResults &results, 10630f22429SAlex Zinenko transform::TransformState &state) { 1071d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 10830f22429SAlex Zinenko } 10940a8bd63SAlex Zinenko 11040a8bd63SAlex Zinenko // No side effects. 11140a8bd63SAlex Zinenko void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 11230f22429SAlex Zinenko }; 113d064c480SAlex Zinenko } // namespace 114d064c480SAlex Zinenko 1151d45282aSAlex Zinenko DiagnosedSilenceableFailure 116a7026288SAlex Zinenko mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( 117c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 118d064c480SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 119d064c480SAlex Zinenko if (getOperation()->getNumOperands() != 0) { 1200e37ef08SMatthias Springer results.set(cast<OpResult>(getResult()), 1210e37ef08SMatthias Springer {getOperation()->getOperand(0).getDefiningOp()}); 122d064c480SAlex Zinenko } else { 1230e37ef08SMatthias Springer results.set(cast<OpResult>(getResult()), {getOperation()}); 124d064c480SAlex Zinenko } 1251d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 126d064c480SAlex Zinenko } 127d064c480SAlex Zinenko 128a7026288SAlex Zinenko void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( 1290242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1300242b962SAlex Zinenko if (getOperand()) 1312c1ae801Sdonald chen transform::onlyReadsHandle(getOperandMutable(), effects); 1322c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 1330242b962SAlex Zinenko } 1340242b962SAlex Zinenko 135a7026288SAlex Zinenko DiagnosedSilenceableFailure 136a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToSelfOperand::apply( 137c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 138a7026288SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 139085075a5SMatthias Springer results.setValues(llvm::cast<OpResult>(getOut()), {getIn()}); 140a7026288SAlex Zinenko return DiagnosedSilenceableFailure::success(); 141a7026288SAlex Zinenko } 142a7026288SAlex Zinenko 143a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( 144a7026288SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1452c1ae801Sdonald chen transform::onlyReadsHandle(getInMutable(), effects); 1462c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 147a7026288SAlex Zinenko transform::onlyReadsPayload(effects); 148a7026288SAlex Zinenko } 149a7026288SAlex Zinenko 150a7026288SAlex Zinenko DiagnosedSilenceableFailure 151a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToResult::applyToOne( 152c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 153c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 154a7026288SAlex Zinenko transform::TransformState &state) { 155a7026288SAlex Zinenko if (target->getNumResults() <= getNumber()) 156a7026288SAlex Zinenko return emitSilenceableError() << "payload has no result #" << getNumber(); 157a7026288SAlex Zinenko results.push_back(target->getResult(getNumber())); 158a7026288SAlex Zinenko return DiagnosedSilenceableFailure::success(); 159a7026288SAlex Zinenko } 160a7026288SAlex Zinenko 161a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToResult::getEffects( 162a7026288SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1632c1ae801Sdonald chen transform::onlyReadsHandle(getInMutable(), effects); 1642c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 165a7026288SAlex Zinenko transform::onlyReadsPayload(effects); 166a7026288SAlex Zinenko } 167a7026288SAlex Zinenko 168a7026288SAlex Zinenko DiagnosedSilenceableFailure 169a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( 170c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 171c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 172a7026288SAlex Zinenko transform::TransformState &state) { 173a7026288SAlex Zinenko if (!target->getBlock()) 174a7026288SAlex Zinenko return emitSilenceableError() << "payload has no parent block"; 175a7026288SAlex Zinenko if (target->getBlock()->getNumArguments() <= getNumber()) 176a7026288SAlex Zinenko return emitSilenceableError() 177a7026288SAlex Zinenko << "parent of the payload has no argument #" << getNumber(); 178a7026288SAlex Zinenko results.push_back(target->getBlock()->getArgument(getNumber())); 179a7026288SAlex Zinenko return DiagnosedSilenceableFailure::success(); 180a7026288SAlex Zinenko } 181a7026288SAlex Zinenko 182a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( 183a7026288SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1842c1ae801Sdonald chen transform::onlyReadsHandle(getInMutable(), effects); 1852c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 186a7026288SAlex Zinenko transform::onlyReadsPayload(effects); 187d064c480SAlex Zinenko } 188d064c480SAlex Zinenko 1891d9a1139SAlex Zinenko bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { 1901d9a1139SAlex Zinenko return getAllowRepeatedHandles(); 1911d9a1139SAlex Zinenko } 1921d9a1139SAlex Zinenko 1931d45282aSAlex Zinenko DiagnosedSilenceableFailure 194c63d2b2cSMatthias Springer mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, 195c63d2b2cSMatthias Springer transform::TransformResults &results, 1966403e1b1SAlex Zinenko transform::TransformState &state) { 1971d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 1986403e1b1SAlex Zinenko } 1996403e1b1SAlex Zinenko 2000242b962SAlex Zinenko void mlir::test::TestConsumeOperand::getEffects( 2010242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2022c1ae801Sdonald chen transform::consumesHandle(getOperation()->getOpOperands(), effects); 2030242b962SAlex Zinenko if (getSecondOperand()) 2042c1ae801Sdonald chen transform::consumesHandle(getSecondOperandMutable(), effects); 2050242b962SAlex Zinenko transform::modifiesPayload(effects); 2060242b962SAlex Zinenko } 2070242b962SAlex Zinenko 208a7026288SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( 209c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 210d064c480SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 2110e37ef08SMatthias Springer auto payload = state.getPayloadOps(getOperand()); 2120e37ef08SMatthias Springer assert(llvm::hasSingleElement(payload) && "expected a single target op"); 2130e37ef08SMatthias Springer if ((*payload.begin())->getName().getStringRef() != getOpKind()) { 2141d45282aSAlex Zinenko return emitSilenceableError() 215a7026288SAlex Zinenko << "op expected the operand to be associated a payload op of kind " 2160e37ef08SMatthias Springer << getOpKind() << " got " 2170e37ef08SMatthias Springer << (*payload.begin())->getName().getStringRef(); 218d064c480SAlex Zinenko } 219d064c480SAlex Zinenko 220d064c480SAlex Zinenko emitRemark() << "succeeded"; 2211d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 222d064c480SAlex Zinenko } 223d064c480SAlex Zinenko 224a7026288SAlex Zinenko void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( 2250242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2262c1ae801Sdonald chen transform::consumesHandle(getOperation()->getOpOperands(), effects); 2270242b962SAlex Zinenko transform::modifiesPayload(effects); 2280242b962SAlex Zinenko } 2290242b962SAlex Zinenko 23063c9d2b1SAlex Zinenko DiagnosedSilenceableFailure 23163c9d2b1SAlex Zinenko mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( 23263c9d2b1SAlex Zinenko Operation *op, transform::TransformResults &results, 23363c9d2b1SAlex Zinenko transform::TransformState &state) { 23463c9d2b1SAlex Zinenko if (op->getName().getStringRef() != getOpKind()) { 23563c9d2b1SAlex Zinenko return emitSilenceableError() 23663c9d2b1SAlex Zinenko << "op expected the operand to be associated with a payload op of " 23763c9d2b1SAlex Zinenko "kind " 23863c9d2b1SAlex Zinenko << getOpKind() << " got " << op->getName().getStringRef(); 23963c9d2b1SAlex Zinenko } 24063c9d2b1SAlex Zinenko return DiagnosedSilenceableFailure::success(); 24163c9d2b1SAlex Zinenko } 24263c9d2b1SAlex Zinenko 24363c9d2b1SAlex Zinenko void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( 24463c9d2b1SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2452c1ae801Sdonald chen transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 24663c9d2b1SAlex Zinenko transform::onlyReadsPayload(effects); 24763c9d2b1SAlex Zinenko } 24863c9d2b1SAlex Zinenko 249c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( 250c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 251c63d2b2cSMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 2526c57b0deSAlex Zinenko state.addExtension<TestTransformStateExtension>(getMessageAttr()); 2531d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 2546c57b0deSAlex Zinenko } 2556c57b0deSAlex Zinenko 2561d45282aSAlex Zinenko DiagnosedSilenceableFailure 2571d45282aSAlex Zinenko mlir::test::TestCheckIfTestExtensionPresentOp::apply( 258c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 2596c57b0deSAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 2606c57b0deSAlex Zinenko auto *extension = state.getExtension<TestTransformStateExtension>(); 2616c57b0deSAlex Zinenko if (!extension) { 2626c57b0deSAlex Zinenko emitRemark() << "extension absent"; 2631d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 2646c57b0deSAlex Zinenko } 2656c57b0deSAlex Zinenko 2666c57b0deSAlex Zinenko InFlightDiagnostic diag = emitRemark() 2676c57b0deSAlex Zinenko << "extension present, " << extension->getMessage(); 2686c57b0deSAlex Zinenko for (Operation *payload : state.getPayloadOps(getOperand())) { 2696c57b0deSAlex Zinenko diag.attachNote(payload->getLoc()) << "associated payload op"; 2703dfea727SAlex Zinenko #ifndef NDEBUG 2713dfea727SAlex Zinenko SmallVector<Value> handles; 2723dfea727SAlex Zinenko assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); 2733dfea727SAlex Zinenko assert(llvm::is_contained(handles, getOperand()) && 2746c57b0deSAlex Zinenko "inconsistent mapping between transform IR handles and payload IR " 2756c57b0deSAlex Zinenko "operations"); 2763dfea727SAlex Zinenko #endif // NDEBUG 2776c57b0deSAlex Zinenko } 2786c57b0deSAlex Zinenko 2791d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 2806c57b0deSAlex Zinenko } 2816c57b0deSAlex Zinenko 2820242b962SAlex Zinenko void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( 2830242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2842c1ae801Sdonald chen transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 2850242b962SAlex Zinenko transform::onlyReadsPayload(effects); 2860242b962SAlex Zinenko } 2870242b962SAlex Zinenko 2881d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( 289c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 2906c57b0deSAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 2916c57b0deSAlex Zinenko auto *extension = state.getExtension<TestTransformStateExtension>(); 292b0bf7fffSAlex Zinenko if (!extension) 293b0bf7fffSAlex Zinenko return emitDefiniteFailure("TestTransformStateExtension missing"); 2946c57b0deSAlex Zinenko 2950e37ef08SMatthias Springer if (failed(extension->updateMapping( 2960e37ef08SMatthias Springer *state.getPayloadOps(getOperand()).begin(), getOperation()))) 2971d45282aSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 298d9db5a59SAlex Zinenko if (getNumResults() > 0) 2990e37ef08SMatthias Springer results.set(cast<OpResult>(getResult(0)), {getOperation()}); 3001d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 301e3890b7fSAlex Zinenko } 302e3890b7fSAlex Zinenko 3030242b962SAlex Zinenko void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( 3040242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3052c1ae801Sdonald chen transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 3062c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 3070242b962SAlex Zinenko transform::onlyReadsPayload(effects); 3080242b962SAlex Zinenko } 3090242b962SAlex Zinenko 3101d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( 311c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 3126c57b0deSAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 3136c57b0deSAlex Zinenko state.removeExtension<TestTransformStateExtension>(); 3141d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 3156c57b0deSAlex Zinenko } 316a299539aSMatthias Springer 317c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( 318c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 319c63d2b2cSMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 3200e37ef08SMatthias Springer auto payloadOps = state.getPayloadOps(getTarget()); 321a299539aSMatthias Springer auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); 322c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getResult()), reversedOps); 323a299539aSMatthias Springer return DiagnosedSilenceableFailure::success(); 324a299539aSMatthias Springer } 325a299539aSMatthias Springer 3261d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( 327c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 32873c3dff1SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 3291d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 33073c3dff1SAlex Zinenko } 33173c3dff1SAlex Zinenko 33273c3dff1SAlex Zinenko void mlir::test::TestTransformOpWithRegions::getEffects( 33373c3dff1SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 33473c3dff1SAlex Zinenko 3351d45282aSAlex Zinenko DiagnosedSilenceableFailure 336e3890b7fSAlex Zinenko mlir::test::TestBranchingTransformOpTerminator::apply( 337c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 33873c3dff1SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 3391d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 34073c3dff1SAlex Zinenko } 34173c3dff1SAlex Zinenko 34273c3dff1SAlex Zinenko void mlir::test::TestBranchingTransformOpTerminator::getEffects( 34373c3dff1SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 3446c57b0deSAlex Zinenko 3451d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( 346c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 347e3890b7fSAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 348e3890b7fSAlex Zinenko emitRemark() << getRemark(); 349e3890b7fSAlex Zinenko for (Operation *op : state.getPayloadOps(getTarget())) 3504f63252dSMatthias Springer rewriter.eraseOp(op); 351e3890b7fSAlex Zinenko 352e3890b7fSAlex Zinenko if (getFailAfterErase()) 353a60ed954SAlex Zinenko return emitSilenceableError() << "silenceable error"; 3541d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 355e3890b7fSAlex Zinenko } 356e3890b7fSAlex Zinenko 35707fef178SMatthias Springer void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( 35807fef178SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3592c1ae801Sdonald chen transform::consumesHandle(getTargetMutable(), effects); 36007fef178SMatthias Springer transform::modifiesPayload(effects); 36107fef178SMatthias Springer } 36207fef178SMatthias Springer 36352307109SNicolas Vasilache DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( 364c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 365c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 36652307109SNicolas Vasilache transform::TransformState &state) { 36752307109SNicolas Vasilache OperationState opState(target->getLoc(), "foo"); 36852307109SNicolas Vasilache results.push_back(OpBuilder(target).create(opState)); 36952307109SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 370f439b319SNicolas Vasilache } 371f439b319SNicolas Vasilache 37252307109SNicolas Vasilache DiagnosedSilenceableFailure 3734c7225d1SNicolas Vasilache mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( 374c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 375c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 37652307109SNicolas Vasilache transform::TransformState &state) { 3774c7225d1SNicolas Vasilache static int count = 0; 37852307109SNicolas Vasilache if (count++ == 0) { 37952307109SNicolas Vasilache OperationState opState(target->getLoc(), "foo"); 38052307109SNicolas Vasilache results.push_back(OpBuilder(target).create(opState)); 38152307109SNicolas Vasilache } 38252307109SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 3834c7225d1SNicolas Vasilache } 3844c7225d1SNicolas Vasilache 38552307109SNicolas Vasilache DiagnosedSilenceableFailure 3864c7225d1SNicolas Vasilache mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( 387c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 388c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 38952307109SNicolas Vasilache transform::TransformState &state) { 39052307109SNicolas Vasilache OperationState opState(target->getLoc(), "foo"); 39152307109SNicolas Vasilache results.push_back(OpBuilder(target).create(opState)); 39252307109SNicolas Vasilache results.push_back(OpBuilder(target).create(opState)); 39352307109SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 39452307109SNicolas Vasilache } 39552307109SNicolas Vasilache 39652307109SNicolas Vasilache DiagnosedSilenceableFailure 39752307109SNicolas Vasilache mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( 398c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 399c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 40052307109SNicolas Vasilache transform::TransformState &state) { 40152307109SNicolas Vasilache OperationState opState(target->getLoc(), "foo"); 40252307109SNicolas Vasilache results.push_back(nullptr); 40352307109SNicolas Vasilache results.push_back(OpBuilder(target).create(opState)); 40452307109SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 4054c7225d1SNicolas Vasilache } 4064c7225d1SNicolas Vasilache 40769c8319eSNicolas Vasilache DiagnosedSilenceableFailure 40869c8319eSNicolas Vasilache mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( 409c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 410c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 41169c8319eSNicolas Vasilache transform::TransformState &state) { 41269c8319eSNicolas Vasilache if (target->hasAttr("target_me")) 41369c8319eSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 41469c8319eSNicolas Vasilache return emitDefaultSilenceableFailure(target); 41569c8319eSNicolas Vasilache } 41669c8319eSNicolas Vasilache 41700d1a1a2SAlex Zinenko DiagnosedSilenceableFailure 418c63d2b2cSMatthias Springer mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, 419c63d2b2cSMatthias Springer transform::TransformResults &results, 4203dfea727SAlex Zinenko transform::TransformState &state) { 421c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getCopy()), 422c1fa60b4STres Popp state.getPayloadOps(getHandle())); 4233dfea727SAlex Zinenko return DiagnosedSilenceableFailure::success(); 4243dfea727SAlex Zinenko } 4253dfea727SAlex Zinenko 4260242b962SAlex Zinenko void mlir::test::TestCopyPayloadOp::getEffects( 4270242b962SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 4282c1ae801Sdonald chen transform::onlyReadsHandle(getHandleMutable(), effects); 4292c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 4300242b962SAlex Zinenko transform::onlyReadsPayload(effects); 4310242b962SAlex Zinenko } 4320242b962SAlex Zinenko 433bba85ebdSAlex Zinenko DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( 434bba85ebdSAlex Zinenko Location loc, ArrayRef<Operation *> payload) const { 435bba85ebdSAlex Zinenko if (payload.empty()) 436bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::success(); 437bba85ebdSAlex Zinenko 438bba85ebdSAlex Zinenko for (Operation *op : payload) { 439bba85ebdSAlex Zinenko if (op->getName().getDialectNamespace() != "test") { 440ed02fa81SAlex Zinenko return emitSilenceableError(loc) << "expected the payload operation to " 441ed02fa81SAlex Zinenko "belong to the 'test' dialect"; 442bba85ebdSAlex Zinenko } 443bba85ebdSAlex Zinenko } 444bba85ebdSAlex Zinenko 445bba85ebdSAlex Zinenko return DiagnosedSilenceableFailure::success(); 446bba85ebdSAlex Zinenko } 447bba85ebdSAlex Zinenko 448ed02fa81SAlex Zinenko DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( 449ed02fa81SAlex Zinenko Location loc, ArrayRef<Attribute> payload) const { 450ed02fa81SAlex Zinenko for (Attribute attr : payload) { 451c1fa60b4STres Popp auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr); 452ed02fa81SAlex Zinenko if (integerAttr && integerAttr.getType().isSignlessInteger(32)) 453ed02fa81SAlex Zinenko continue; 454ed02fa81SAlex Zinenko return emitSilenceableError(loc) 455ed02fa81SAlex Zinenko << "expected the parameter to be a i32 integer attribute"; 456ed02fa81SAlex Zinenko } 457ed02fa81SAlex Zinenko 458ed02fa81SAlex Zinenko return DiagnosedSilenceableFailure::success(); 459ed02fa81SAlex Zinenko } 460ed02fa81SAlex Zinenko 461d8cab3f4SNicolas Vasilache void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( 462d8cab3f4SNicolas Vasilache SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 4632c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 464d8cab3f4SNicolas Vasilache } 465d8cab3f4SNicolas Vasilache 466d8cab3f4SNicolas Vasilache DiagnosedSilenceableFailure 467d8cab3f4SNicolas Vasilache mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( 468c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 469d8cab3f4SNicolas Vasilache transform::TransformResults &results, transform::TransformState &state) { 470d8cab3f4SNicolas Vasilache int64_t count = 0; 471d8cab3f4SNicolas Vasilache for (Operation *op : state.getPayloadOps(getTarget())) { 472d8cab3f4SNicolas Vasilache op->walk([&](Operation *nested) { 473d8cab3f4SNicolas Vasilache SmallVector<Value> handles; 474d8cab3f4SNicolas Vasilache (void)state.getHandlesForPayloadOp(nested, handles); 475d8cab3f4SNicolas Vasilache count += handles.size(); 476d8cab3f4SNicolas Vasilache }); 477d8cab3f4SNicolas Vasilache } 478d8cab3f4SNicolas Vasilache emitRemark() << count << " handles nested under"; 479d8cab3f4SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 480d8cab3f4SNicolas Vasilache } 481d8cab3f4SNicolas Vasilache 482ed02fa81SAlex Zinenko DiagnosedSilenceableFailure 483c63d2b2cSMatthias Springer mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, 484c63d2b2cSMatthias Springer transform::TransformResults &results, 485ed02fa81SAlex Zinenko transform::TransformState &state) { 486ed02fa81SAlex Zinenko SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0); 487ed02fa81SAlex Zinenko if (Value param = getParam()) { 488ed02fa81SAlex Zinenko values = llvm::to_vector( 489ed02fa81SAlex Zinenko llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { 490c1fa60b4STres Popp return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue( 491ed02fa81SAlex Zinenko UINT32_MAX); 492ed02fa81SAlex Zinenko })); 493ed02fa81SAlex Zinenko } 494ed02fa81SAlex Zinenko 495ed02fa81SAlex Zinenko Builder builder(getContext()); 496ed02fa81SAlex Zinenko SmallVector<Attribute> result = llvm::to_vector( 497ed02fa81SAlex Zinenko llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { 498ed02fa81SAlex Zinenko return builder.getI32IntegerAttr(value + getAddendum()); 499ed02fa81SAlex Zinenko })); 500c1fa60b4STres Popp results.setParams(llvm::cast<OpResult>(getResult()), result); 501ed02fa81SAlex Zinenko return DiagnosedSilenceableFailure::success(); 502ed02fa81SAlex Zinenko } 503ed02fa81SAlex Zinenko 504ed02fa81SAlex Zinenko DiagnosedSilenceableFailure 505ed02fa81SAlex Zinenko mlir::test::TestProduceParamWithNumberOfTestOps::apply( 506c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 507ed02fa81SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 508ed02fa81SAlex Zinenko Builder builder(getContext()); 509ed02fa81SAlex Zinenko SmallVector<Attribute> result = llvm::to_vector( 510ed02fa81SAlex Zinenko llvm::map_range(state.getPayloadOps(getHandle()), 511ed02fa81SAlex Zinenko [&builder](Operation *payload) -> Attribute { 512ed02fa81SAlex Zinenko int32_t count = 0; 513ed02fa81SAlex Zinenko payload->walk([&count](Operation *op) { 514ed02fa81SAlex Zinenko if (op->getName().getDialectNamespace() == "test") 515ed02fa81SAlex Zinenko ++count; 516ed02fa81SAlex Zinenko }); 517ed02fa81SAlex Zinenko return builder.getI32IntegerAttr(count); 518ed02fa81SAlex Zinenko })); 519c1fa60b4STres Popp results.setParams(llvm::cast<OpResult>(getResult()), result); 520ed02fa81SAlex Zinenko return DiagnosedSilenceableFailure::success(); 521ed02fa81SAlex Zinenko } 522ed02fa81SAlex Zinenko 523ed02fa81SAlex Zinenko DiagnosedSilenceableFailure 524dd81c6b8SAlex Zinenko mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, 525dd81c6b8SAlex Zinenko transform::TransformResults &results, 526dd81c6b8SAlex Zinenko transform::TransformState &state) { 527dd81c6b8SAlex Zinenko results.setParams(llvm::cast<OpResult>(getResult()), getAttr()); 528ed02fa81SAlex Zinenko return DiagnosedSilenceableFailure::success(); 529ed02fa81SAlex Zinenko } 530ed02fa81SAlex Zinenko 5314b455a71SAlex Zinenko void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( 5324b455a71SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 5332c1ae801Sdonald chen transform::onlyReadsHandle(getInMutable(), effects); 5342c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 5354b455a71SAlex Zinenko } 5364b455a71SAlex Zinenko 5374b455a71SAlex Zinenko DiagnosedSilenceableFailure 5384b455a71SAlex Zinenko mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( 539c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 540c63d2b2cSMatthias Springer ::transform::ApplyToEachResultList &results, 5414b455a71SAlex Zinenko ::transform::TransformState &state) { 5424b455a71SAlex Zinenko Builder builder(getContext()); 5434b455a71SAlex Zinenko if (getFirstResultIsParam()) { 5444b455a71SAlex Zinenko results.push_back(builder.getI64IntegerAttr(0)); 5454b455a71SAlex Zinenko } else if (getFirstResultIsNull()) { 5464b455a71SAlex Zinenko results.push_back(nullptr); 5474b455a71SAlex Zinenko } else { 5480e37ef08SMatthias Springer results.push_back(*state.getPayloadOps(getIn()).begin()); 5494b455a71SAlex Zinenko } 5504b455a71SAlex Zinenko 5514b455a71SAlex Zinenko if (getSecondResultIsHandle()) { 5520e37ef08SMatthias Springer results.push_back(*state.getPayloadOps(getIn()).begin()); 5534b455a71SAlex Zinenko } else { 5544b455a71SAlex Zinenko results.push_back(builder.getI64IntegerAttr(42)); 5554b455a71SAlex Zinenko } 5564b455a71SAlex Zinenko 5574b455a71SAlex Zinenko return DiagnosedSilenceableFailure::success(); 5584b455a71SAlex Zinenko } 5594b455a71SAlex Zinenko 560984c2c8cSAlex Zinenko void mlir::test::TestProduceNullPayloadOp::getEffects( 561984c2c8cSAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 5622c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 563984c2c8cSAlex Zinenko } 564984c2c8cSAlex Zinenko 565984c2c8cSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( 566c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 567984c2c8cSAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 568984c2c8cSAlex Zinenko SmallVector<Operation *, 1> null({nullptr}); 569c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getOut()), null); 570984c2c8cSAlex Zinenko return DiagnosedSilenceableFailure::success(); 571984c2c8cSAlex Zinenko } 572984c2c8cSAlex Zinenko 57344f6e862SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( 574c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 57544f6e862SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 57644f6e862SAlex Zinenko results.set(cast<OpResult>(getOut()), {}); 57744f6e862SAlex Zinenko return DiagnosedSilenceableFailure::success(); 57844f6e862SAlex Zinenko } 57944f6e862SAlex Zinenko 580984c2c8cSAlex Zinenko void mlir::test::TestProduceNullParamOp::getEffects( 58139489284SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 5822c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 58339489284SAlex Zinenko } 584984c2c8cSAlex Zinenko 585c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( 586c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 587c63d2b2cSMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 588c1fa60b4STres Popp results.setParams(llvm::cast<OpResult>(getOut()), Attribute()); 589984c2c8cSAlex Zinenko return DiagnosedSilenceableFailure::success(); 590984c2c8cSAlex Zinenko } 591984c2c8cSAlex Zinenko 592a7026288SAlex Zinenko void mlir::test::TestProduceNullValueOp::getEffects( 593a7026288SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 5942c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 595a7026288SAlex Zinenko } 596a7026288SAlex Zinenko 597c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( 598c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 599c63d2b2cSMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 600085075a5SMatthias Springer results.setValues(llvm::cast<OpResult>(getOut()), {Value()}); 601a7026288SAlex Zinenko return DiagnosedSilenceableFailure::success(); 602a7026288SAlex Zinenko } 603a7026288SAlex Zinenko 60439489284SAlex Zinenko void mlir::test::TestRequiredMemoryEffectsOp::getEffects( 60539489284SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 60639489284SAlex Zinenko if (getHasOperandEffect()) 6072c1ae801Sdonald chen transform::consumesHandle(getInMutable(), effects); 60839489284SAlex Zinenko 6092c1ae801Sdonald chen if (getHasResultEffect()) { 6102c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 6112c1ae801Sdonald chen } else { 6122c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Read::get(), 6132c1ae801Sdonald chen llvm::cast<OpResult>(getOut()), 6142c1ae801Sdonald chen transform::TransformMappingResource::get()); 6152c1ae801Sdonald chen } 6160242b962SAlex Zinenko 6170242b962SAlex Zinenko if (getModifiesPayload()) 6180242b962SAlex Zinenko transform::modifiesPayload(effects); 61939489284SAlex Zinenko } 62039489284SAlex Zinenko 62139489284SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( 622c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, 62339489284SAlex Zinenko transform::TransformResults &results, transform::TransformState &state) { 624c1fa60b4STres Popp results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn())); 62539489284SAlex Zinenko return DiagnosedSilenceableFailure::success(); 62639489284SAlex Zinenko } 62739489284SAlex Zinenko 6280e37ef08SMatthias Springer void mlir::test::TestTrackedRewriteOp::getEffects( 6290e37ef08SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 6302c1ae801Sdonald chen transform::onlyReadsHandle(getInMutable(), effects); 6310e37ef08SMatthias Springer transform::modifiesPayload(effects); 6320e37ef08SMatthias Springer } 6330e37ef08SMatthias Springer 6341b390f5eSMatthias Springer void mlir::test::TestDummyPayloadOp::getEffects( 6351b390f5eSMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 6362c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 6370e37ef08SMatthias Springer } 6380e37ef08SMatthias Springer 6397dfcd4b7SMatthias Springer LogicalResult mlir::test::TestDummyPayloadOp::verify() { 6407dfcd4b7SMatthias Springer if (getFailToVerify()) 6417dfcd4b7SMatthias Springer return emitOpError() << "fail_to_verify is set"; 6427dfcd4b7SMatthias Springer return success(); 6437dfcd4b7SMatthias Springer } 6447dfcd4b7SMatthias Springer 6450e37ef08SMatthias Springer DiagnosedSilenceableFailure 646c63d2b2cSMatthias Springer mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, 647c63d2b2cSMatthias Springer transform::TransformResults &results, 6480e37ef08SMatthias Springer transform::TransformState &state) { 6490e37ef08SMatthias Springer int64_t numIterations = 0; 6500e37ef08SMatthias Springer 6510e37ef08SMatthias Springer // `getPayloadOps` returns an iterator that skips ops that are erased in the 6520e37ef08SMatthias Springer // loop body. Replacement ops are not enumerated. 6530e37ef08SMatthias Springer for (Operation *op : state.getPayloadOps(getIn())) { 6540e37ef08SMatthias Springer ++numIterations; 6551b390f5eSMatthias Springer (void)op; 6560e37ef08SMatthias Springer 6570e37ef08SMatthias Springer // Erase all payload ops. The outer loop should have only one iteration. 6580e37ef08SMatthias Springer for (Operation *op : state.getPayloadOps(getIn())) { 6591b390f5eSMatthias Springer rewriter.setInsertionPoint(op); 6601b390f5eSMatthias Springer if (op->hasAttr("erase_me")) { 6611b390f5eSMatthias Springer rewriter.eraseOp(op); 6620e37ef08SMatthias Springer continue; 6631b390f5eSMatthias Springer } 6641b390f5eSMatthias Springer if (!op->hasAttr("replace_me")) { 6650e37ef08SMatthias Springer continue; 6661b390f5eSMatthias Springer } 6671b390f5eSMatthias Springer 6680e37ef08SMatthias Springer SmallVector<NamedAttribute> attributes; 6691b390f5eSMatthias Springer attributes.emplace_back(rewriter.getStringAttr("new_op"), 6701b390f5eSMatthias Springer rewriter.getUnitAttr()); 6711b390f5eSMatthias Springer OperationState opState(op->getLoc(), op->getName().getIdentifier(), 6720e37ef08SMatthias Springer /*operands=*/ValueRange(), 6730e37ef08SMatthias Springer /*types=*/op->getResultTypes(), attributes); 6740e37ef08SMatthias Springer Operation *newOp = rewriter.create(opState); 6750e37ef08SMatthias Springer rewriter.replaceOp(op, newOp->getResults()); 6760e37ef08SMatthias Springer } 6770e37ef08SMatthias Springer } 6780e37ef08SMatthias Springer 6790e37ef08SMatthias Springer emitRemark() << numIterations << " iterations"; 6800e37ef08SMatthias Springer return DiagnosedSilenceableFailure::success(); 6810e37ef08SMatthias Springer } 6820e37ef08SMatthias Springer 683d064c480SAlex Zinenko namespace { 6840b52fa90SMatthias Springer // Test pattern to replace an operation with a new op. 6850b52fa90SMatthias Springer class ReplaceWithNewOp : public RewritePattern { 6860b52fa90SMatthias Springer public: 6870b52fa90SMatthias Springer ReplaceWithNewOp(MLIRContext *context) 6880b52fa90SMatthias Springer : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 6890b52fa90SMatthias Springer 6900b52fa90SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 6910b52fa90SMatthias Springer PatternRewriter &rewriter) const override { 6920b52fa90SMatthias Springer auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op"); 6930b52fa90SMatthias Springer if (!newName) 6940b52fa90SMatthias Springer return failure(); 6950b52fa90SMatthias Springer Operation *newOp = rewriter.create( 6960b52fa90SMatthias Springer op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), 6970b52fa90SMatthias Springer op->getOperands(), op->getResultTypes()); 6980b52fa90SMatthias Springer rewriter.replaceOp(op, newOp->getResults()); 6990b52fa90SMatthias Springer return success(); 7000b52fa90SMatthias Springer } 7010b52fa90SMatthias Springer }; 7020b52fa90SMatthias Springer 7030b52fa90SMatthias Springer // Test pattern to erase an operation. 7040b52fa90SMatthias Springer class EraseOp : public RewritePattern { 7050b52fa90SMatthias Springer public: 7060b52fa90SMatthias Springer EraseOp(MLIRContext *context) 7070b52fa90SMatthias Springer : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 7080b52fa90SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 7090b52fa90SMatthias Springer PatternRewriter &rewriter) const override { 7100b52fa90SMatthias Springer rewriter.eraseOp(op); 7110b52fa90SMatthias Springer return success(); 7120b52fa90SMatthias Springer } 7130b52fa90SMatthias Springer }; 7145a10f207SMatthias Springer } // namespace 7150b52fa90SMatthias Springer 7165a10f207SMatthias Springer void mlir::test::ApplyTestPatternsOp::populatePatterns( 7175a10f207SMatthias Springer RewritePatternSet &patterns) { 7185a10f207SMatthias Springer patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext()); 7195a10f207SMatthias Springer } 7205a10f207SMatthias Springer 721c580bd26SAlex Zinenko void mlir::test::TestReEnterRegionOp::getEffects( 722c580bd26SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 7232c1ae801Sdonald chen transform::consumesHandle(getOperation()->getOpOperands(), effects); 724c580bd26SAlex Zinenko transform::modifiesPayload(effects); 725c580bd26SAlex Zinenko } 726c580bd26SAlex Zinenko 727c580bd26SAlex Zinenko DiagnosedSilenceableFailure 728c580bd26SAlex Zinenko mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter, 729c580bd26SAlex Zinenko transform::TransformResults &results, 730c580bd26SAlex Zinenko transform::TransformState &state) { 731c580bd26SAlex Zinenko 732c580bd26SAlex Zinenko SmallVector<SmallVector<transform::MappedValue>> mappings; 733c580bd26SAlex Zinenko for (BlockArgument arg : getBody().front().getArguments()) { 734c580bd26SAlex Zinenko mappings.emplace_back(llvm::to_vector(llvm::map_range( 735c580bd26SAlex Zinenko state.getPayloadOps(getOperand(arg.getArgNumber())), 736c580bd26SAlex Zinenko [](Operation *op) -> transform::MappedValue { return op; }))); 737c580bd26SAlex Zinenko } 738c580bd26SAlex Zinenko 739c580bd26SAlex Zinenko for (int i = 0; i < 4; ++i) { 740c580bd26SAlex Zinenko auto scope = state.make_region_scope(getBody()); 741c580bd26SAlex Zinenko for (BlockArgument arg : getBody().front().getArguments()) { 742c580bd26SAlex Zinenko if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) 743c580bd26SAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 744c580bd26SAlex Zinenko } 745c580bd26SAlex Zinenko for (Operation &op : getBody().front().without_terminator()) { 746c580bd26SAlex Zinenko DiagnosedSilenceableFailure diag = 747c580bd26SAlex Zinenko state.applyTransform(cast<transform::TransformOpInterface>(op)); 748c580bd26SAlex Zinenko if (!diag.succeeded()) 749c580bd26SAlex Zinenko return diag; 750c580bd26SAlex Zinenko } 751c580bd26SAlex Zinenko } 752c580bd26SAlex Zinenko return DiagnosedSilenceableFailure::success(); 753c580bd26SAlex Zinenko } 754c580bd26SAlex Zinenko 755c580bd26SAlex Zinenko LogicalResult mlir::test::TestReEnterRegionOp::verify() { 756c580bd26SAlex Zinenko if (getNumOperands() != getBody().front().getNumArguments()) { 757c580bd26SAlex Zinenko return emitOpError() << "expects as many operands as block arguments"; 758c580bd26SAlex Zinenko } 759c580bd26SAlex Zinenko return success(); 760c580bd26SAlex Zinenko } 761c580bd26SAlex Zinenko 7629cc8e458SMatthias Springer DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( 7639cc8e458SMatthias Springer transform::TransformRewriter &rewriter, 7649cc8e458SMatthias Springer transform::TransformResults &results, transform::TransformState &state) { 7659cc8e458SMatthias Springer auto originalOps = state.getPayloadOps(getOriginal()); 7669cc8e458SMatthias Springer auto replacementOps = state.getPayloadOps(getReplacement()); 7679cc8e458SMatthias Springer if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) 7689cc8e458SMatthias Springer return emitSilenceableError() << "expected same number of original and " 7699cc8e458SMatthias Springer "replacement payload operations"; 7709cc8e458SMatthias Springer for (const auto &[original, replacement] : 7719cc8e458SMatthias Springer llvm::zip(originalOps, replacementOps)) { 7729cc8e458SMatthias Springer if (failed( 7739cc8e458SMatthias Springer rewriter.notifyPayloadOperationReplaced(original, replacement))) { 7749cc8e458SMatthias Springer auto diag = emitSilenceableError() 7759cc8e458SMatthias Springer << "unable to replace payload op in transform mapping"; 7769cc8e458SMatthias Springer diag.attachNote(original->getLoc()) << "original payload op"; 7779cc8e458SMatthias Springer diag.attachNote(replacement->getLoc()) << "replacement payload op"; 7789cc8e458SMatthias Springer return diag; 7799cc8e458SMatthias Springer } 7809cc8e458SMatthias Springer } 7819cc8e458SMatthias Springer return DiagnosedSilenceableFailure::success(); 7829cc8e458SMatthias Springer } 7839cc8e458SMatthias Springer 7849cc8e458SMatthias Springer void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( 7859cc8e458SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 7862c1ae801Sdonald chen transform::onlyReadsHandle(getOriginalMutable(), effects); 7872c1ae801Sdonald chen transform::onlyReadsHandle(getReplacementMutable(), effects); 7889cc8e458SMatthias Springer } 7899cc8e458SMatthias Springer 7907dfcd4b7SMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( 7917dfcd4b7SMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 7927dfcd4b7SMatthias Springer transform::ApplyToEachResultList &results, 7937dfcd4b7SMatthias Springer transform::TransformState &state) { 7947dfcd4b7SMatthias Springer // Provide some IR that does not verify. 7957dfcd4b7SMatthias Springer rewriter.setInsertionPointToStart(&target->getRegion(0).front()); 7967dfcd4b7SMatthias Springer rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), 7977dfcd4b7SMatthias Springer ValueRange(), /*failToVerify=*/true); 7987dfcd4b7SMatthias Springer return DiagnosedSilenceableFailure::success(); 7997dfcd4b7SMatthias Springer } 8007dfcd4b7SMatthias Springer 8017dfcd4b7SMatthias Springer void mlir::test::TestProduceInvalidIR::getEffects( 8027dfcd4b7SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 8032c1ae801Sdonald chen transform::onlyReadsHandle(getTargetMutable(), effects); 8047dfcd4b7SMatthias Springer transform::modifiesPayload(effects); 8057dfcd4b7SMatthias Springer } 8067dfcd4b7SMatthias Springer 8076634d44eSAmy Wang DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply( 8086634d44eSAmy Wang transform::TransformRewriter &rewriter, 8096634d44eSAmy Wang transform::TransformResults &results, transform::TransformState &state) { 8106634d44eSAmy Wang std::string opName = 8116634d44eSAmy Wang this->getOperationName().str() + "_" + getTypeAttr().str(); 8126634d44eSAmy Wang TransformStateInitializerExtension *initExt = 8136634d44eSAmy Wang state.getExtension<TransformStateInitializerExtension>(); 8146634d44eSAmy Wang if (!initExt) { 8156634d44eSAmy Wang emitRemark() << "\nSpecified extension not found, adding a new one!\n"; 8166634d44eSAmy Wang SmallVector<std::string> opCollection = {opName}; 8176634d44eSAmy Wang state.addExtension<TransformStateInitializerExtension>(1, opCollection); 8186634d44eSAmy Wang } else { 8196634d44eSAmy Wang initExt->setNumOp(initExt->getNumOp() + 1); 8206634d44eSAmy Wang initExt->pushRegisteredOps(opName); 8216634d44eSAmy Wang InFlightDiagnostic diag = emitRemark() 8226634d44eSAmy Wang << "Number of currently registered op: " 8236634d44eSAmy Wang << initExt->getNumOp() << "\n" 8246634d44eSAmy Wang << initExt->printMessage() << "\n"; 8256634d44eSAmy Wang } 8266634d44eSAmy Wang return DiagnosedSilenceableFailure::success(); 8276634d44eSAmy Wang } 8286634d44eSAmy Wang 8295a10f207SMatthias Springer namespace { 830bcfdb3e4SMatthias Springer /// Test conversion pattern that replaces ops with the "replace_with_new_op" 831bcfdb3e4SMatthias Springer /// attribute with "test.new_op". 832bcfdb3e4SMatthias Springer class ReplaceWithNewOpConversion : public ConversionPattern { 833bcfdb3e4SMatthias Springer public: 834bcfdb3e4SMatthias Springer ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) 835bcfdb3e4SMatthias Springer : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), 836bcfdb3e4SMatthias Springer /*benefit=*/1, context) {} 837bcfdb3e4SMatthias Springer 838bcfdb3e4SMatthias Springer LogicalResult 839bcfdb3e4SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 840bcfdb3e4SMatthias Springer ConversionPatternRewriter &rewriter) const override { 841bcfdb3e4SMatthias Springer if (!op->hasAttr("replace_with_new_op")) 842bcfdb3e4SMatthias Springer return failure(); 843bcfdb3e4SMatthias Springer SmallVector<Type> newResultTypes; 844bcfdb3e4SMatthias Springer if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), 845bcfdb3e4SMatthias Springer newResultTypes))) 846bcfdb3e4SMatthias Springer return failure(); 847bcfdb3e4SMatthias Springer Operation *newOp = rewriter.create( 848bcfdb3e4SMatthias Springer op->getLoc(), 849bcfdb3e4SMatthias Springer OperationName("test.new_op", op->getContext()).getIdentifier(), 850bcfdb3e4SMatthias Springer operands, newResultTypes); 851bcfdb3e4SMatthias Springer rewriter.replaceOp(op, newOp->getResults()); 852bcfdb3e4SMatthias Springer return success(); 853bcfdb3e4SMatthias Springer } 854bcfdb3e4SMatthias Springer }; 855bcfdb3e4SMatthias Springer } // namespace 856bcfdb3e4SMatthias Springer 857bcfdb3e4SMatthias Springer void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( 858bcfdb3e4SMatthias Springer TypeConverter &typeConverter, RewritePatternSet &patterns) { 859bcfdb3e4SMatthias Springer patterns.insert<ReplaceWithNewOpConversion>(typeConverter, 860bcfdb3e4SMatthias Springer patterns.getContext()); 861bcfdb3e4SMatthias Springer } 862bcfdb3e4SMatthias Springer 863bcfdb3e4SMatthias Springer namespace { 864bcfdb3e4SMatthias Springer /// Test type converter that converts tensor types to memref types. 865bcfdb3e4SMatthias Springer class TestTypeConverter : public TypeConverter { 866bcfdb3e4SMatthias Springer public: 867bcfdb3e4SMatthias Springer TestTypeConverter() { 868e2d39f79SChristopher Bate addConversion([](Type t) { return t; }); 869bcfdb3e4SMatthias Springer addConversion([](RankedTensorType type) -> Type { 870bcfdb3e4SMatthias Springer return MemRefType::get(type.getShape(), type.getElementType()); 871bcfdb3e4SMatthias Springer }); 872e2d39f79SChristopher Bate auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, 873bcfdb3e4SMatthias Springer ValueRange inputs, 874*f18c3e4eSMatthias Springer Location loc) -> Value { 875bcfdb3e4SMatthias Springer if (inputs.size() != 1) 876*f18c3e4eSMatthias Springer return Value(); 877bcfdb3e4SMatthias Springer return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) 878bcfdb3e4SMatthias Springer .getResult(0); 879e2d39f79SChristopher Bate }; 880e2d39f79SChristopher Bate addSourceMaterialization(unrealizedCastConverter); 881e2d39f79SChristopher Bate addTargetMaterialization(unrealizedCastConverter); 882bcfdb3e4SMatthias Springer } 883bcfdb3e4SMatthias Springer }; 884bcfdb3e4SMatthias Springer } // namespace 885bcfdb3e4SMatthias Springer 886bcfdb3e4SMatthias Springer std::unique_ptr<::mlir::TypeConverter> 887bcfdb3e4SMatthias Springer mlir::test::TestTypeConverterOp::getTypeConverter() { 888bcfdb3e4SMatthias Springer return std::make_unique<TestTypeConverter>(); 889bcfdb3e4SMatthias Springer } 890bcfdb3e4SMatthias Springer 891bcfdb3e4SMatthias Springer namespace { 892d064c480SAlex Zinenko /// Test extension of the Transform dialect. Registers additional ops and 893d064c480SAlex Zinenko /// declares PDL as dependent dialect since the additional ops are using PDL 894d064c480SAlex Zinenko /// types for operands and results. 895d064c480SAlex Zinenko class TestTransformDialectExtension 896d064c480SAlex Zinenko : public transform::TransformDialectExtension< 897d064c480SAlex Zinenko TestTransformDialectExtension> { 898d064c480SAlex Zinenko public: 89984cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) 90084cc1865SNikhil Kalra 901333ee218SAlex Zinenko using Base::Base; 902333ee218SAlex Zinenko 903333ee218SAlex Zinenko void init() { 904d064c480SAlex Zinenko declareDependentDialect<pdl::PDLDialect>(); 905d064c480SAlex Zinenko registerTransformOps<TestTransformOp, 90630f22429SAlex Zinenko TestTransformUnrestrictedOpNoInterface, 907d064c480SAlex Zinenko #define GET_OP_LIST 908d064c480SAlex Zinenko #include "TestTransformDialectExtension.cpp.inc" 909d064c480SAlex Zinenko >(); 910bba85ebdSAlex Zinenko registerTypes< 911bba85ebdSAlex Zinenko #define GET_TYPEDEF_LIST 912bba85ebdSAlex Zinenko #include "TestTransformDialectExtensionTypes.cpp.inc" 913bba85ebdSAlex Zinenko >(); 91494d608d4SAlex Zinenko 9158ec28af8SMatthias Gehre auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, 91694d608d4SAlex Zinenko ArrayRef<PDLValue> pdlValues) { 91794d608d4SAlex Zinenko for (const PDLValue &pdlValue : pdlValues) { 91894d608d4SAlex Zinenko if (Operation *op = pdlValue.dyn_cast<Operation *>()) { 91994d608d4SAlex Zinenko op->emitWarning() << "from PDL constraint"; 92094d608d4SAlex Zinenko } 92194d608d4SAlex Zinenko } 92294d608d4SAlex Zinenko return success(); 92394d608d4SAlex Zinenko }; 92494d608d4SAlex Zinenko 92594d608d4SAlex Zinenko addDialectDataInitializer<transform::PDLMatchHooks>( 92694d608d4SAlex Zinenko [&](transform::PDLMatchHooks &hooks) { 92794d608d4SAlex Zinenko llvm::StringMap<PDLConstraintFunction> constraints; 92894d608d4SAlex Zinenko constraints.try_emplace("verbose_constraint", verboseConstraint); 92994d608d4SAlex Zinenko hooks.mergeInPDLMatchHooks(std::move(constraints)); 93094d608d4SAlex Zinenko }); 931d064c480SAlex Zinenko } 932d064c480SAlex Zinenko }; 933d064c480SAlex Zinenko } // namespace 934d064c480SAlex Zinenko 935bba85ebdSAlex Zinenko // These are automatically generated by ODS but are not used as the Transform 936bba85ebdSAlex Zinenko // dialect uses a different dispatch mechanism to support dialect extensions. 937bba85ebdSAlex Zinenko LLVM_ATTRIBUTE_UNUSED static OptionalParseResult 938bba85ebdSAlex Zinenko generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); 939bba85ebdSAlex Zinenko LLVM_ATTRIBUTE_UNUSED static LogicalResult 940bba85ebdSAlex Zinenko generatedTypePrinter(Type def, AsmPrinter &printer); 941bba85ebdSAlex Zinenko 942bba85ebdSAlex Zinenko #define GET_TYPEDEF_CLASSES 943bba85ebdSAlex Zinenko #include "TestTransformDialectExtensionTypes.cpp.inc" 944bba85ebdSAlex Zinenko 945ed02fa81SAlex Zinenko #define GET_OP_CLASSES 946ed02fa81SAlex Zinenko #include "TestTransformDialectExtension.cpp.inc" 947ed02fa81SAlex Zinenko 948d064c480SAlex Zinenko void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 949d064c480SAlex Zinenko registry.addExtensions<TestTransformDialectExtension>(); 950d064c480SAlex Zinenko } 951