xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (revision f18c3e4e7335df282c468b6dff3d29be1822a96d)
1d064c480SAlex Zinenko //===- TestTransformDialectExtension.cpp ----------------------------------===//
2d064c480SAlex Zinenko //
3d064c480SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d064c480SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5d064c480SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d064c480SAlex Zinenko //
7d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
8d064c480SAlex Zinenko //
9d064c480SAlex Zinenko // This file defines an extension of the MLIR Transform dialect for testing
10d064c480SAlex Zinenko // purposes.
11d064c480SAlex Zinenko //
12d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
13d064c480SAlex Zinenko 
14d064c480SAlex Zinenko #include "TestTransformDialectExtension.h"
156c57b0deSAlex Zinenko #include "TestTransformStateExtension.h"
16d064c480SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDL.h"
17d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
180e37ef08SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformOps.h"
195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2094d608d4SAlex Zinenko #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
21d064c480SAlex Zinenko #include "mlir/IR/OpImplementation.h"
2294d608d4SAlex Zinenko #include "mlir/IR/PatternMatch.h"
23ed02fa81SAlex Zinenko #include "llvm/ADT/STLExtras.h"
24bba85ebdSAlex Zinenko #include "llvm/ADT/TypeSwitch.h"
25bba85ebdSAlex Zinenko #include "llvm/Support/Compiler.h"
26ed02fa81SAlex Zinenko #include "llvm/Support/raw_ostream.h"
27d064c480SAlex Zinenko 
28d064c480SAlex Zinenko using namespace mlir;
29d064c480SAlex Zinenko 
30d064c480SAlex Zinenko namespace {
31d064c480SAlex Zinenko /// Simple transform op defined outside of the dialect. Just emits a remark when
3230f22429SAlex Zinenko /// applied. This op is defined in C++ to test that C++ definitions also work
3330f22429SAlex Zinenko /// for op injection into the Transform dialect.
34d064c480SAlex Zinenko class TestTransformOp
3540a8bd63SAlex Zinenko     : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
3640a8bd63SAlex Zinenko                 MemoryEffectOpInterface::Trait> {
37d064c480SAlex Zinenko public:
38d064c480SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
39d064c480SAlex Zinenko 
40d064c480SAlex Zinenko   using Op::Op;
41d064c480SAlex Zinenko 
42d064c480SAlex Zinenko   static ArrayRef<StringRef> getAttributeNames() { return {}; }
43d064c480SAlex Zinenko 
44d064c480SAlex Zinenko   static constexpr llvm::StringLiteral getOperationName() {
45d064c480SAlex Zinenko     return llvm::StringLiteral("transform.test_transform_op");
46d064c480SAlex Zinenko   }
47d064c480SAlex Zinenko 
48c63d2b2cSMatthias Springer   DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
49c63d2b2cSMatthias Springer                                     transform::TransformResults &results,
50d064c480SAlex Zinenko                                     transform::TransformState &state) {
510eb403adSAlex Zinenko     InFlightDiagnostic remark = emitRemark() << "applying transformation";
520eb403adSAlex Zinenko     if (Attribute message = getMessage())
530eb403adSAlex Zinenko       remark << " " << message;
540eb403adSAlex Zinenko 
551d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::success();
56d064c480SAlex Zinenko   }
57d064c480SAlex Zinenko 
58830b9b07SMehdi Amini   Attribute getMessage() {
59830b9b07SMehdi Amini     return getOperation()->getDiscardableAttr("message");
60830b9b07SMehdi Amini   }
610eb403adSAlex Zinenko 
62d064c480SAlex Zinenko   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
630eb403adSAlex Zinenko     StringAttr message;
640eb403adSAlex Zinenko     OptionalParseResult result = parser.parseOptionalAttribute(message);
659750648cSKazu Hirata     if (!result.has_value())
66d064c480SAlex Zinenko       return success();
670eb403adSAlex Zinenko 
68c8e6ebd7SKazu Hirata     if (result.value().succeeded())
690eb403adSAlex Zinenko       state.addAttribute("message", message);
70c8e6ebd7SKazu Hirata     return result.value();
71d064c480SAlex Zinenko   }
72d064c480SAlex Zinenko 
730eb403adSAlex Zinenko   void print(OpAsmPrinter &printer) {
740eb403adSAlex Zinenko     if (getMessage())
750eb403adSAlex Zinenko       printer << " " << getMessage();
760eb403adSAlex Zinenko   }
7740a8bd63SAlex Zinenko 
7840a8bd63SAlex Zinenko   // No side effects.
7940a8bd63SAlex Zinenko   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
80d064c480SAlex Zinenko };
8130f22429SAlex Zinenko 
8230f22429SAlex Zinenko /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
8330f22429SAlex Zinenko /// in cases where it is attached to ops that do not comply with the trait
8430f22429SAlex Zinenko /// requirements. This op cannot be defined in ODS because ODS generates strict
8530f22429SAlex Zinenko /// verifiers that overalp with those in the trait and run earlier.
8630f22429SAlex Zinenko class TestTransformUnrestrictedOpNoInterface
8730f22429SAlex Zinenko     : public Op<TestTransformUnrestrictedOpNoInterface,
8830f22429SAlex Zinenko                 transform::PossibleTopLevelTransformOpTrait,
8940a8bd63SAlex Zinenko                 transform::TransformOpInterface::Trait,
9040a8bd63SAlex Zinenko                 MemoryEffectOpInterface::Trait> {
9130f22429SAlex Zinenko public:
9230f22429SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
9330f22429SAlex Zinenko       TestTransformUnrestrictedOpNoInterface)
9430f22429SAlex Zinenko 
9530f22429SAlex Zinenko   using Op::Op;
9630f22429SAlex Zinenko 
9730f22429SAlex Zinenko   static ArrayRef<StringRef> getAttributeNames() { return {}; }
9830f22429SAlex Zinenko 
9930f22429SAlex Zinenko   static constexpr llvm::StringLiteral getOperationName() {
10030f22429SAlex Zinenko     return llvm::StringLiteral(
10130f22429SAlex Zinenko         "transform.test_transform_unrestricted_op_no_interface");
10230f22429SAlex Zinenko   }
10330f22429SAlex Zinenko 
104c63d2b2cSMatthias Springer   DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
105c63d2b2cSMatthias Springer                                     transform::TransformResults &results,
10630f22429SAlex Zinenko                                     transform::TransformState &state) {
1071d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::success();
10830f22429SAlex Zinenko   }
10940a8bd63SAlex Zinenko 
11040a8bd63SAlex Zinenko   // No side effects.
11140a8bd63SAlex Zinenko   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
11230f22429SAlex Zinenko };
113d064c480SAlex Zinenko } // namespace
114d064c480SAlex Zinenko 
1151d45282aSAlex Zinenko DiagnosedSilenceableFailure
116a7026288SAlex Zinenko mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
117c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
118d064c480SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
119d064c480SAlex Zinenko   if (getOperation()->getNumOperands() != 0) {
1200e37ef08SMatthias Springer     results.set(cast<OpResult>(getResult()),
1210e37ef08SMatthias Springer                 {getOperation()->getOperand(0).getDefiningOp()});
122d064c480SAlex Zinenko   } else {
1230e37ef08SMatthias Springer     results.set(cast<OpResult>(getResult()), {getOperation()});
124d064c480SAlex Zinenko   }
1251d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
126d064c480SAlex Zinenko }
127d064c480SAlex Zinenko 
128a7026288SAlex Zinenko void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
1290242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1300242b962SAlex Zinenko   if (getOperand())
1312c1ae801Sdonald chen     transform::onlyReadsHandle(getOperandMutable(), effects);
1322c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
1330242b962SAlex Zinenko }
1340242b962SAlex Zinenko 
135a7026288SAlex Zinenko DiagnosedSilenceableFailure
136a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToSelfOperand::apply(
137c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
138a7026288SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
139085075a5SMatthias Springer   results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
140a7026288SAlex Zinenko   return DiagnosedSilenceableFailure::success();
141a7026288SAlex Zinenko }
142a7026288SAlex Zinenko 
143a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
144a7026288SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1452c1ae801Sdonald chen   transform::onlyReadsHandle(getInMutable(), effects);
1462c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
147a7026288SAlex Zinenko   transform::onlyReadsPayload(effects);
148a7026288SAlex Zinenko }
149a7026288SAlex Zinenko 
150a7026288SAlex Zinenko DiagnosedSilenceableFailure
151a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToResult::applyToOne(
152c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
153c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
154a7026288SAlex Zinenko     transform::TransformState &state) {
155a7026288SAlex Zinenko   if (target->getNumResults() <= getNumber())
156a7026288SAlex Zinenko     return emitSilenceableError() << "payload has no result #" << getNumber();
157a7026288SAlex Zinenko   results.push_back(target->getResult(getNumber()));
158a7026288SAlex Zinenko   return DiagnosedSilenceableFailure::success();
159a7026288SAlex Zinenko }
160a7026288SAlex Zinenko 
161a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToResult::getEffects(
162a7026288SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1632c1ae801Sdonald chen   transform::onlyReadsHandle(getInMutable(), effects);
1642c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
165a7026288SAlex Zinenko   transform::onlyReadsPayload(effects);
166a7026288SAlex Zinenko }
167a7026288SAlex Zinenko 
168a7026288SAlex Zinenko DiagnosedSilenceableFailure
169a7026288SAlex Zinenko mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
170c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
171c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
172a7026288SAlex Zinenko     transform::TransformState &state) {
173a7026288SAlex Zinenko   if (!target->getBlock())
174a7026288SAlex Zinenko     return emitSilenceableError() << "payload has no parent block";
175a7026288SAlex Zinenko   if (target->getBlock()->getNumArguments() <= getNumber())
176a7026288SAlex Zinenko     return emitSilenceableError()
177a7026288SAlex Zinenko            << "parent of the payload has no argument #" << getNumber();
178a7026288SAlex Zinenko   results.push_back(target->getBlock()->getArgument(getNumber()));
179a7026288SAlex Zinenko   return DiagnosedSilenceableFailure::success();
180a7026288SAlex Zinenko }
181a7026288SAlex Zinenko 
182a7026288SAlex Zinenko void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
183a7026288SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1842c1ae801Sdonald chen   transform::onlyReadsHandle(getInMutable(), effects);
1852c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
186a7026288SAlex Zinenko   transform::onlyReadsPayload(effects);
187d064c480SAlex Zinenko }
188d064c480SAlex Zinenko 
1891d9a1139SAlex Zinenko bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
1901d9a1139SAlex Zinenko   return getAllowRepeatedHandles();
1911d9a1139SAlex Zinenko }
1921d9a1139SAlex Zinenko 
1931d45282aSAlex Zinenko DiagnosedSilenceableFailure
194c63d2b2cSMatthias Springer mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
195c63d2b2cSMatthias Springer                                       transform::TransformResults &results,
1966403e1b1SAlex Zinenko                                       transform::TransformState &state) {
1971d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
1986403e1b1SAlex Zinenko }
1996403e1b1SAlex Zinenko 
2000242b962SAlex Zinenko void mlir::test::TestConsumeOperand::getEffects(
2010242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2022c1ae801Sdonald chen   transform::consumesHandle(getOperation()->getOpOperands(), effects);
2030242b962SAlex Zinenko   if (getSecondOperand())
2042c1ae801Sdonald chen     transform::consumesHandle(getSecondOperandMutable(), effects);
2050242b962SAlex Zinenko   transform::modifiesPayload(effects);
2060242b962SAlex Zinenko }
2070242b962SAlex Zinenko 
208a7026288SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
209c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
210d064c480SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
2110e37ef08SMatthias Springer   auto payload = state.getPayloadOps(getOperand());
2120e37ef08SMatthias Springer   assert(llvm::hasSingleElement(payload) && "expected a single target op");
2130e37ef08SMatthias Springer   if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
2141d45282aSAlex Zinenko     return emitSilenceableError()
215a7026288SAlex Zinenko            << "op expected the operand to be associated a payload op of kind "
2160e37ef08SMatthias Springer            << getOpKind() << " got "
2170e37ef08SMatthias Springer            << (*payload.begin())->getName().getStringRef();
218d064c480SAlex Zinenko   }
219d064c480SAlex Zinenko 
220d064c480SAlex Zinenko   emitRemark() << "succeeded";
2211d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
222d064c480SAlex Zinenko }
223d064c480SAlex Zinenko 
224a7026288SAlex Zinenko void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
2250242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2262c1ae801Sdonald chen   transform::consumesHandle(getOperation()->getOpOperands(), effects);
2270242b962SAlex Zinenko   transform::modifiesPayload(effects);
2280242b962SAlex Zinenko }
2290242b962SAlex Zinenko 
23063c9d2b1SAlex Zinenko DiagnosedSilenceableFailure
23163c9d2b1SAlex Zinenko mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
23263c9d2b1SAlex Zinenko     Operation *op, transform::TransformResults &results,
23363c9d2b1SAlex Zinenko     transform::TransformState &state) {
23463c9d2b1SAlex Zinenko   if (op->getName().getStringRef() != getOpKind()) {
23563c9d2b1SAlex Zinenko     return emitSilenceableError()
23663c9d2b1SAlex Zinenko            << "op expected the operand to be associated with a payload op of "
23763c9d2b1SAlex Zinenko               "kind "
23863c9d2b1SAlex Zinenko            << getOpKind() << " got " << op->getName().getStringRef();
23963c9d2b1SAlex Zinenko   }
24063c9d2b1SAlex Zinenko   return DiagnosedSilenceableFailure::success();
24163c9d2b1SAlex Zinenko }
24263c9d2b1SAlex Zinenko 
24363c9d2b1SAlex Zinenko void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
24463c9d2b1SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2452c1ae801Sdonald chen   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
24663c9d2b1SAlex Zinenko   transform::onlyReadsPayload(effects);
24763c9d2b1SAlex Zinenko }
24863c9d2b1SAlex Zinenko 
249c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
250c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
251c63d2b2cSMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
2526c57b0deSAlex Zinenko   state.addExtension<TestTransformStateExtension>(getMessageAttr());
2531d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
2546c57b0deSAlex Zinenko }
2556c57b0deSAlex Zinenko 
2561d45282aSAlex Zinenko DiagnosedSilenceableFailure
2571d45282aSAlex Zinenko mlir::test::TestCheckIfTestExtensionPresentOp::apply(
258c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
2596c57b0deSAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
2606c57b0deSAlex Zinenko   auto *extension = state.getExtension<TestTransformStateExtension>();
2616c57b0deSAlex Zinenko   if (!extension) {
2626c57b0deSAlex Zinenko     emitRemark() << "extension absent";
2631d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::success();
2646c57b0deSAlex Zinenko   }
2656c57b0deSAlex Zinenko 
2666c57b0deSAlex Zinenko   InFlightDiagnostic diag = emitRemark()
2676c57b0deSAlex Zinenko                             << "extension present, " << extension->getMessage();
2686c57b0deSAlex Zinenko   for (Operation *payload : state.getPayloadOps(getOperand())) {
2696c57b0deSAlex Zinenko     diag.attachNote(payload->getLoc()) << "associated payload op";
2703dfea727SAlex Zinenko #ifndef NDEBUG
2713dfea727SAlex Zinenko     SmallVector<Value> handles;
2723dfea727SAlex Zinenko     assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
2733dfea727SAlex Zinenko     assert(llvm::is_contained(handles, getOperand()) &&
2746c57b0deSAlex Zinenko            "inconsistent mapping between transform IR handles and payload IR "
2756c57b0deSAlex Zinenko            "operations");
2763dfea727SAlex Zinenko #endif // NDEBUG
2776c57b0deSAlex Zinenko   }
2786c57b0deSAlex Zinenko 
2791d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
2806c57b0deSAlex Zinenko }
2816c57b0deSAlex Zinenko 
2820242b962SAlex Zinenko void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
2830242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2842c1ae801Sdonald chen   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
2850242b962SAlex Zinenko   transform::onlyReadsPayload(effects);
2860242b962SAlex Zinenko }
2870242b962SAlex Zinenko 
2881d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
289c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
2906c57b0deSAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
2916c57b0deSAlex Zinenko   auto *extension = state.getExtension<TestTransformStateExtension>();
292b0bf7fffSAlex Zinenko   if (!extension)
293b0bf7fffSAlex Zinenko     return emitDefiniteFailure("TestTransformStateExtension missing");
2946c57b0deSAlex Zinenko 
2950e37ef08SMatthias Springer   if (failed(extension->updateMapping(
2960e37ef08SMatthias Springer           *state.getPayloadOps(getOperand()).begin(), getOperation())))
2971d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
298d9db5a59SAlex Zinenko   if (getNumResults() > 0)
2990e37ef08SMatthias Springer     results.set(cast<OpResult>(getResult(0)), {getOperation()});
3001d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
301e3890b7fSAlex Zinenko }
302e3890b7fSAlex Zinenko 
3030242b962SAlex Zinenko void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
3040242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3052c1ae801Sdonald chen   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
3062c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
3070242b962SAlex Zinenko   transform::onlyReadsPayload(effects);
3080242b962SAlex Zinenko }
3090242b962SAlex Zinenko 
3101d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
311c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
3126c57b0deSAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
3136c57b0deSAlex Zinenko   state.removeExtension<TestTransformStateExtension>();
3141d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
3156c57b0deSAlex Zinenko }
316a299539aSMatthias Springer 
317c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
318c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
319c63d2b2cSMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
3200e37ef08SMatthias Springer   auto payloadOps = state.getPayloadOps(getTarget());
321a299539aSMatthias Springer   auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
322c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
323a299539aSMatthias Springer   return DiagnosedSilenceableFailure::success();
324a299539aSMatthias Springer }
325a299539aSMatthias Springer 
3261d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
327c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
32873c3dff1SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
3291d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
33073c3dff1SAlex Zinenko }
33173c3dff1SAlex Zinenko 
33273c3dff1SAlex Zinenko void mlir::test::TestTransformOpWithRegions::getEffects(
33373c3dff1SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
33473c3dff1SAlex Zinenko 
3351d45282aSAlex Zinenko DiagnosedSilenceableFailure
336e3890b7fSAlex Zinenko mlir::test::TestBranchingTransformOpTerminator::apply(
337c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
33873c3dff1SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
3391d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
34073c3dff1SAlex Zinenko }
34173c3dff1SAlex Zinenko 
34273c3dff1SAlex Zinenko void mlir::test::TestBranchingTransformOpTerminator::getEffects(
34373c3dff1SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
3446c57b0deSAlex Zinenko 
3451d45282aSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
346c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
347e3890b7fSAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
348e3890b7fSAlex Zinenko   emitRemark() << getRemark();
349e3890b7fSAlex Zinenko   for (Operation *op : state.getPayloadOps(getTarget()))
3504f63252dSMatthias Springer     rewriter.eraseOp(op);
351e3890b7fSAlex Zinenko 
352e3890b7fSAlex Zinenko   if (getFailAfterErase())
353a60ed954SAlex Zinenko     return emitSilenceableError() << "silenceable error";
3541d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
355e3890b7fSAlex Zinenko }
356e3890b7fSAlex Zinenko 
35707fef178SMatthias Springer void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
35807fef178SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3592c1ae801Sdonald chen   transform::consumesHandle(getTargetMutable(), effects);
36007fef178SMatthias Springer   transform::modifiesPayload(effects);
36107fef178SMatthias Springer }
36207fef178SMatthias Springer 
36352307109SNicolas Vasilache DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
364c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
365c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
36652307109SNicolas Vasilache     transform::TransformState &state) {
36752307109SNicolas Vasilache   OperationState opState(target->getLoc(), "foo");
36852307109SNicolas Vasilache   results.push_back(OpBuilder(target).create(opState));
36952307109SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
370f439b319SNicolas Vasilache }
371f439b319SNicolas Vasilache 
37252307109SNicolas Vasilache DiagnosedSilenceableFailure
3734c7225d1SNicolas Vasilache mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
374c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
375c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
37652307109SNicolas Vasilache     transform::TransformState &state) {
3774c7225d1SNicolas Vasilache   static int count = 0;
37852307109SNicolas Vasilache   if (count++ == 0) {
37952307109SNicolas Vasilache     OperationState opState(target->getLoc(), "foo");
38052307109SNicolas Vasilache     results.push_back(OpBuilder(target).create(opState));
38152307109SNicolas Vasilache   }
38252307109SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
3834c7225d1SNicolas Vasilache }
3844c7225d1SNicolas Vasilache 
38552307109SNicolas Vasilache DiagnosedSilenceableFailure
3864c7225d1SNicolas Vasilache mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
387c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
388c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
38952307109SNicolas Vasilache     transform::TransformState &state) {
39052307109SNicolas Vasilache   OperationState opState(target->getLoc(), "foo");
39152307109SNicolas Vasilache   results.push_back(OpBuilder(target).create(opState));
39252307109SNicolas Vasilache   results.push_back(OpBuilder(target).create(opState));
39352307109SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
39452307109SNicolas Vasilache }
39552307109SNicolas Vasilache 
39652307109SNicolas Vasilache DiagnosedSilenceableFailure
39752307109SNicolas Vasilache mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
398c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
399c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
40052307109SNicolas Vasilache     transform::TransformState &state) {
40152307109SNicolas Vasilache   OperationState opState(target->getLoc(), "foo");
40252307109SNicolas Vasilache   results.push_back(nullptr);
40352307109SNicolas Vasilache   results.push_back(OpBuilder(target).create(opState));
40452307109SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
4054c7225d1SNicolas Vasilache }
4064c7225d1SNicolas Vasilache 
40769c8319eSNicolas Vasilache DiagnosedSilenceableFailure
40869c8319eSNicolas Vasilache mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
409c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
410c63d2b2cSMatthias Springer     transform::ApplyToEachResultList &results,
41169c8319eSNicolas Vasilache     transform::TransformState &state) {
41269c8319eSNicolas Vasilache   if (target->hasAttr("target_me"))
41369c8319eSNicolas Vasilache     return DiagnosedSilenceableFailure::success();
41469c8319eSNicolas Vasilache   return emitDefaultSilenceableFailure(target);
41569c8319eSNicolas Vasilache }
41669c8319eSNicolas Vasilache 
41700d1a1a2SAlex Zinenko DiagnosedSilenceableFailure
418c63d2b2cSMatthias Springer mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
419c63d2b2cSMatthias Springer                                      transform::TransformResults &results,
4203dfea727SAlex Zinenko                                      transform::TransformState &state) {
421c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getCopy()),
422c1fa60b4STres Popp               state.getPayloadOps(getHandle()));
4233dfea727SAlex Zinenko   return DiagnosedSilenceableFailure::success();
4243dfea727SAlex Zinenko }
4253dfea727SAlex Zinenko 
4260242b962SAlex Zinenko void mlir::test::TestCopyPayloadOp::getEffects(
4270242b962SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4282c1ae801Sdonald chen   transform::onlyReadsHandle(getHandleMutable(), effects);
4292c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
4300242b962SAlex Zinenko   transform::onlyReadsPayload(effects);
4310242b962SAlex Zinenko }
4320242b962SAlex Zinenko 
433bba85ebdSAlex Zinenko DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
434bba85ebdSAlex Zinenko     Location loc, ArrayRef<Operation *> payload) const {
435bba85ebdSAlex Zinenko   if (payload.empty())
436bba85ebdSAlex Zinenko     return DiagnosedSilenceableFailure::success();
437bba85ebdSAlex Zinenko 
438bba85ebdSAlex Zinenko   for (Operation *op : payload) {
439bba85ebdSAlex Zinenko     if (op->getName().getDialectNamespace() != "test") {
440ed02fa81SAlex Zinenko       return emitSilenceableError(loc) << "expected the payload operation to "
441ed02fa81SAlex Zinenko                                           "belong to the 'test' dialect";
442bba85ebdSAlex Zinenko     }
443bba85ebdSAlex Zinenko   }
444bba85ebdSAlex Zinenko 
445bba85ebdSAlex Zinenko   return DiagnosedSilenceableFailure::success();
446bba85ebdSAlex Zinenko }
447bba85ebdSAlex Zinenko 
448ed02fa81SAlex Zinenko DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
449ed02fa81SAlex Zinenko     Location loc, ArrayRef<Attribute> payload) const {
450ed02fa81SAlex Zinenko   for (Attribute attr : payload) {
451c1fa60b4STres Popp     auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
452ed02fa81SAlex Zinenko     if (integerAttr && integerAttr.getType().isSignlessInteger(32))
453ed02fa81SAlex Zinenko       continue;
454ed02fa81SAlex Zinenko     return emitSilenceableError(loc)
455ed02fa81SAlex Zinenko            << "expected the parameter to be a i32 integer attribute";
456ed02fa81SAlex Zinenko   }
457ed02fa81SAlex Zinenko 
458ed02fa81SAlex Zinenko   return DiagnosedSilenceableFailure::success();
459ed02fa81SAlex Zinenko }
460ed02fa81SAlex Zinenko 
461d8cab3f4SNicolas Vasilache void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
462d8cab3f4SNicolas Vasilache     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4632c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
464d8cab3f4SNicolas Vasilache }
465d8cab3f4SNicolas Vasilache 
466d8cab3f4SNicolas Vasilache DiagnosedSilenceableFailure
467d8cab3f4SNicolas Vasilache mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
468c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
469d8cab3f4SNicolas Vasilache     transform::TransformResults &results, transform::TransformState &state) {
470d8cab3f4SNicolas Vasilache   int64_t count = 0;
471d8cab3f4SNicolas Vasilache   for (Operation *op : state.getPayloadOps(getTarget())) {
472d8cab3f4SNicolas Vasilache     op->walk([&](Operation *nested) {
473d8cab3f4SNicolas Vasilache       SmallVector<Value> handles;
474d8cab3f4SNicolas Vasilache       (void)state.getHandlesForPayloadOp(nested, handles);
475d8cab3f4SNicolas Vasilache       count += handles.size();
476d8cab3f4SNicolas Vasilache     });
477d8cab3f4SNicolas Vasilache   }
478d8cab3f4SNicolas Vasilache   emitRemark() << count << " handles nested under";
479d8cab3f4SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
480d8cab3f4SNicolas Vasilache }
481d8cab3f4SNicolas Vasilache 
482ed02fa81SAlex Zinenko DiagnosedSilenceableFailure
483c63d2b2cSMatthias Springer mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
484c63d2b2cSMatthias Springer                                     transform::TransformResults &results,
485ed02fa81SAlex Zinenko                                     transform::TransformState &state) {
486ed02fa81SAlex Zinenko   SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
487ed02fa81SAlex Zinenko   if (Value param = getParam()) {
488ed02fa81SAlex Zinenko     values = llvm::to_vector(
489ed02fa81SAlex Zinenko         llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
490c1fa60b4STres Popp           return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
491ed02fa81SAlex Zinenko               UINT32_MAX);
492ed02fa81SAlex Zinenko         }));
493ed02fa81SAlex Zinenko   }
494ed02fa81SAlex Zinenko 
495ed02fa81SAlex Zinenko   Builder builder(getContext());
496ed02fa81SAlex Zinenko   SmallVector<Attribute> result = llvm::to_vector(
497ed02fa81SAlex Zinenko       llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
498ed02fa81SAlex Zinenko         return builder.getI32IntegerAttr(value + getAddendum());
499ed02fa81SAlex Zinenko       }));
500c1fa60b4STres Popp   results.setParams(llvm::cast<OpResult>(getResult()), result);
501ed02fa81SAlex Zinenko   return DiagnosedSilenceableFailure::success();
502ed02fa81SAlex Zinenko }
503ed02fa81SAlex Zinenko 
504ed02fa81SAlex Zinenko DiagnosedSilenceableFailure
505ed02fa81SAlex Zinenko mlir::test::TestProduceParamWithNumberOfTestOps::apply(
506c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
507ed02fa81SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
508ed02fa81SAlex Zinenko   Builder builder(getContext());
509ed02fa81SAlex Zinenko   SmallVector<Attribute> result = llvm::to_vector(
510ed02fa81SAlex Zinenko       llvm::map_range(state.getPayloadOps(getHandle()),
511ed02fa81SAlex Zinenko                       [&builder](Operation *payload) -> Attribute {
512ed02fa81SAlex Zinenko                         int32_t count = 0;
513ed02fa81SAlex Zinenko                         payload->walk([&count](Operation *op) {
514ed02fa81SAlex Zinenko                           if (op->getName().getDialectNamespace() == "test")
515ed02fa81SAlex Zinenko                             ++count;
516ed02fa81SAlex Zinenko                         });
517ed02fa81SAlex Zinenko                         return builder.getI32IntegerAttr(count);
518ed02fa81SAlex Zinenko                       }));
519c1fa60b4STres Popp   results.setParams(llvm::cast<OpResult>(getResult()), result);
520ed02fa81SAlex Zinenko   return DiagnosedSilenceableFailure::success();
521ed02fa81SAlex Zinenko }
522ed02fa81SAlex Zinenko 
523ed02fa81SAlex Zinenko DiagnosedSilenceableFailure
524dd81c6b8SAlex Zinenko mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
525dd81c6b8SAlex Zinenko                                       transform::TransformResults &results,
526dd81c6b8SAlex Zinenko                                       transform::TransformState &state) {
527dd81c6b8SAlex Zinenko   results.setParams(llvm::cast<OpResult>(getResult()), getAttr());
528ed02fa81SAlex Zinenko   return DiagnosedSilenceableFailure::success();
529ed02fa81SAlex Zinenko }
530ed02fa81SAlex Zinenko 
5314b455a71SAlex Zinenko void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
5324b455a71SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
5332c1ae801Sdonald chen   transform::onlyReadsHandle(getInMutable(), effects);
5342c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
5354b455a71SAlex Zinenko }
5364b455a71SAlex Zinenko 
5374b455a71SAlex Zinenko DiagnosedSilenceableFailure
5384b455a71SAlex Zinenko mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
539c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
540c63d2b2cSMatthias Springer     ::transform::ApplyToEachResultList &results,
5414b455a71SAlex Zinenko     ::transform::TransformState &state) {
5424b455a71SAlex Zinenko   Builder builder(getContext());
5434b455a71SAlex Zinenko   if (getFirstResultIsParam()) {
5444b455a71SAlex Zinenko     results.push_back(builder.getI64IntegerAttr(0));
5454b455a71SAlex Zinenko   } else if (getFirstResultIsNull()) {
5464b455a71SAlex Zinenko     results.push_back(nullptr);
5474b455a71SAlex Zinenko   } else {
5480e37ef08SMatthias Springer     results.push_back(*state.getPayloadOps(getIn()).begin());
5494b455a71SAlex Zinenko   }
5504b455a71SAlex Zinenko 
5514b455a71SAlex Zinenko   if (getSecondResultIsHandle()) {
5520e37ef08SMatthias Springer     results.push_back(*state.getPayloadOps(getIn()).begin());
5534b455a71SAlex Zinenko   } else {
5544b455a71SAlex Zinenko     results.push_back(builder.getI64IntegerAttr(42));
5554b455a71SAlex Zinenko   }
5564b455a71SAlex Zinenko 
5574b455a71SAlex Zinenko   return DiagnosedSilenceableFailure::success();
5584b455a71SAlex Zinenko }
5594b455a71SAlex Zinenko 
560984c2c8cSAlex Zinenko void mlir::test::TestProduceNullPayloadOp::getEffects(
561984c2c8cSAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
5622c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
563984c2c8cSAlex Zinenko }
564984c2c8cSAlex Zinenko 
565984c2c8cSAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
566c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
567984c2c8cSAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
568984c2c8cSAlex Zinenko   SmallVector<Operation *, 1> null({nullptr});
569c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getOut()), null);
570984c2c8cSAlex Zinenko   return DiagnosedSilenceableFailure::success();
571984c2c8cSAlex Zinenko }
572984c2c8cSAlex Zinenko 
57344f6e862SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
574c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
57544f6e862SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
57644f6e862SAlex Zinenko   results.set(cast<OpResult>(getOut()), {});
57744f6e862SAlex Zinenko   return DiagnosedSilenceableFailure::success();
57844f6e862SAlex Zinenko }
57944f6e862SAlex Zinenko 
580984c2c8cSAlex Zinenko void mlir::test::TestProduceNullParamOp::getEffects(
58139489284SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
5822c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
58339489284SAlex Zinenko }
584984c2c8cSAlex Zinenko 
585c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
586c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
587c63d2b2cSMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
588c1fa60b4STres Popp   results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
589984c2c8cSAlex Zinenko   return DiagnosedSilenceableFailure::success();
590984c2c8cSAlex Zinenko }
591984c2c8cSAlex Zinenko 
592a7026288SAlex Zinenko void mlir::test::TestProduceNullValueOp::getEffects(
593a7026288SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
5942c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
595a7026288SAlex Zinenko }
596a7026288SAlex Zinenko 
597c63d2b2cSMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
598c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
599c63d2b2cSMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
600085075a5SMatthias Springer   results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
601a7026288SAlex Zinenko   return DiagnosedSilenceableFailure::success();
602a7026288SAlex Zinenko }
603a7026288SAlex Zinenko 
60439489284SAlex Zinenko void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
60539489284SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
60639489284SAlex Zinenko   if (getHasOperandEffect())
6072c1ae801Sdonald chen     transform::consumesHandle(getInMutable(), effects);
60839489284SAlex Zinenko 
6092c1ae801Sdonald chen   if (getHasResultEffect()) {
6102c1ae801Sdonald chen     transform::producesHandle(getOperation()->getOpResults(), effects);
6112c1ae801Sdonald chen   } else {
6122c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Read::get(),
6132c1ae801Sdonald chen                          llvm::cast<OpResult>(getOut()),
6142c1ae801Sdonald chen                          transform::TransformMappingResource::get());
6152c1ae801Sdonald chen   }
6160242b962SAlex Zinenko 
6170242b962SAlex Zinenko   if (getModifiesPayload())
6180242b962SAlex Zinenko     transform::modifiesPayload(effects);
61939489284SAlex Zinenko }
62039489284SAlex Zinenko 
62139489284SAlex Zinenko DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
622c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter,
62339489284SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
624c1fa60b4STres Popp   results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
62539489284SAlex Zinenko   return DiagnosedSilenceableFailure::success();
62639489284SAlex Zinenko }
62739489284SAlex Zinenko 
6280e37ef08SMatthias Springer void mlir::test::TestTrackedRewriteOp::getEffects(
6290e37ef08SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
6302c1ae801Sdonald chen   transform::onlyReadsHandle(getInMutable(), effects);
6310e37ef08SMatthias Springer   transform::modifiesPayload(effects);
6320e37ef08SMatthias Springer }
6330e37ef08SMatthias Springer 
6341b390f5eSMatthias Springer void mlir::test::TestDummyPayloadOp::getEffects(
6351b390f5eSMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
6362c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
6370e37ef08SMatthias Springer }
6380e37ef08SMatthias Springer 
6397dfcd4b7SMatthias Springer LogicalResult mlir::test::TestDummyPayloadOp::verify() {
6407dfcd4b7SMatthias Springer   if (getFailToVerify())
6417dfcd4b7SMatthias Springer     return emitOpError() << "fail_to_verify is set";
6427dfcd4b7SMatthias Springer   return success();
6437dfcd4b7SMatthias Springer }
6447dfcd4b7SMatthias Springer 
6450e37ef08SMatthias Springer DiagnosedSilenceableFailure
646c63d2b2cSMatthias Springer mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
647c63d2b2cSMatthias Springer                                         transform::TransformResults &results,
6480e37ef08SMatthias Springer                                         transform::TransformState &state) {
6490e37ef08SMatthias Springer   int64_t numIterations = 0;
6500e37ef08SMatthias Springer 
6510e37ef08SMatthias Springer   // `getPayloadOps` returns an iterator that skips ops that are erased in the
6520e37ef08SMatthias Springer   // loop body. Replacement ops are not enumerated.
6530e37ef08SMatthias Springer   for (Operation *op : state.getPayloadOps(getIn())) {
6540e37ef08SMatthias Springer     ++numIterations;
6551b390f5eSMatthias Springer     (void)op;
6560e37ef08SMatthias Springer 
6570e37ef08SMatthias Springer     // Erase all payload ops. The outer loop should have only one iteration.
6580e37ef08SMatthias Springer     for (Operation *op : state.getPayloadOps(getIn())) {
6591b390f5eSMatthias Springer       rewriter.setInsertionPoint(op);
6601b390f5eSMatthias Springer       if (op->hasAttr("erase_me")) {
6611b390f5eSMatthias Springer         rewriter.eraseOp(op);
6620e37ef08SMatthias Springer         continue;
6631b390f5eSMatthias Springer       }
6641b390f5eSMatthias Springer       if (!op->hasAttr("replace_me")) {
6650e37ef08SMatthias Springer         continue;
6661b390f5eSMatthias Springer       }
6671b390f5eSMatthias Springer 
6680e37ef08SMatthias Springer       SmallVector<NamedAttribute> attributes;
6691b390f5eSMatthias Springer       attributes.emplace_back(rewriter.getStringAttr("new_op"),
6701b390f5eSMatthias Springer                               rewriter.getUnitAttr());
6711b390f5eSMatthias Springer       OperationState opState(op->getLoc(), op->getName().getIdentifier(),
6720e37ef08SMatthias Springer                              /*operands=*/ValueRange(),
6730e37ef08SMatthias Springer                              /*types=*/op->getResultTypes(), attributes);
6740e37ef08SMatthias Springer       Operation *newOp = rewriter.create(opState);
6750e37ef08SMatthias Springer       rewriter.replaceOp(op, newOp->getResults());
6760e37ef08SMatthias Springer     }
6770e37ef08SMatthias Springer   }
6780e37ef08SMatthias Springer 
6790e37ef08SMatthias Springer   emitRemark() << numIterations << " iterations";
6800e37ef08SMatthias Springer   return DiagnosedSilenceableFailure::success();
6810e37ef08SMatthias Springer }
6820e37ef08SMatthias Springer 
683d064c480SAlex Zinenko namespace {
6840b52fa90SMatthias Springer // Test pattern to replace an operation with a new op.
6850b52fa90SMatthias Springer class ReplaceWithNewOp : public RewritePattern {
6860b52fa90SMatthias Springer public:
6870b52fa90SMatthias Springer   ReplaceWithNewOp(MLIRContext *context)
6880b52fa90SMatthias Springer       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
6890b52fa90SMatthias Springer 
6900b52fa90SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
6910b52fa90SMatthias Springer                                 PatternRewriter &rewriter) const override {
6920b52fa90SMatthias Springer     auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
6930b52fa90SMatthias Springer     if (!newName)
6940b52fa90SMatthias Springer       return failure();
6950b52fa90SMatthias Springer     Operation *newOp = rewriter.create(
6960b52fa90SMatthias Springer         op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
6970b52fa90SMatthias Springer         op->getOperands(), op->getResultTypes());
6980b52fa90SMatthias Springer     rewriter.replaceOp(op, newOp->getResults());
6990b52fa90SMatthias Springer     return success();
7000b52fa90SMatthias Springer   }
7010b52fa90SMatthias Springer };
7020b52fa90SMatthias Springer 
7030b52fa90SMatthias Springer // Test pattern to erase an operation.
7040b52fa90SMatthias Springer class EraseOp : public RewritePattern {
7050b52fa90SMatthias Springer public:
7060b52fa90SMatthias Springer   EraseOp(MLIRContext *context)
7070b52fa90SMatthias Springer       : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
7080b52fa90SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
7090b52fa90SMatthias Springer                                 PatternRewriter &rewriter) const override {
7100b52fa90SMatthias Springer     rewriter.eraseOp(op);
7110b52fa90SMatthias Springer     return success();
7120b52fa90SMatthias Springer   }
7130b52fa90SMatthias Springer };
7145a10f207SMatthias Springer } // namespace
7150b52fa90SMatthias Springer 
7165a10f207SMatthias Springer void mlir::test::ApplyTestPatternsOp::populatePatterns(
7175a10f207SMatthias Springer     RewritePatternSet &patterns) {
7185a10f207SMatthias Springer   patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
7195a10f207SMatthias Springer }
7205a10f207SMatthias Springer 
721c580bd26SAlex Zinenko void mlir::test::TestReEnterRegionOp::getEffects(
722c580bd26SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
7232c1ae801Sdonald chen   transform::consumesHandle(getOperation()->getOpOperands(), effects);
724c580bd26SAlex Zinenko   transform::modifiesPayload(effects);
725c580bd26SAlex Zinenko }
726c580bd26SAlex Zinenko 
727c580bd26SAlex Zinenko DiagnosedSilenceableFailure
728c580bd26SAlex Zinenko mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
729c580bd26SAlex Zinenko                                        transform::TransformResults &results,
730c580bd26SAlex Zinenko                                        transform::TransformState &state) {
731c580bd26SAlex Zinenko 
732c580bd26SAlex Zinenko   SmallVector<SmallVector<transform::MappedValue>> mappings;
733c580bd26SAlex Zinenko   for (BlockArgument arg : getBody().front().getArguments()) {
734c580bd26SAlex Zinenko     mappings.emplace_back(llvm::to_vector(llvm::map_range(
735c580bd26SAlex Zinenko         state.getPayloadOps(getOperand(arg.getArgNumber())),
736c580bd26SAlex Zinenko         [](Operation *op) -> transform::MappedValue { return op; })));
737c580bd26SAlex Zinenko   }
738c580bd26SAlex Zinenko 
739c580bd26SAlex Zinenko   for (int i = 0; i < 4; ++i) {
740c580bd26SAlex Zinenko     auto scope = state.make_region_scope(getBody());
741c580bd26SAlex Zinenko     for (BlockArgument arg : getBody().front().getArguments()) {
742c580bd26SAlex Zinenko       if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
743c580bd26SAlex Zinenko         return DiagnosedSilenceableFailure::definiteFailure();
744c580bd26SAlex Zinenko     }
745c580bd26SAlex Zinenko     for (Operation &op : getBody().front().without_terminator()) {
746c580bd26SAlex Zinenko       DiagnosedSilenceableFailure diag =
747c580bd26SAlex Zinenko           state.applyTransform(cast<transform::TransformOpInterface>(op));
748c580bd26SAlex Zinenko       if (!diag.succeeded())
749c580bd26SAlex Zinenko         return diag;
750c580bd26SAlex Zinenko     }
751c580bd26SAlex Zinenko   }
752c580bd26SAlex Zinenko   return DiagnosedSilenceableFailure::success();
753c580bd26SAlex Zinenko }
754c580bd26SAlex Zinenko 
755c580bd26SAlex Zinenko LogicalResult mlir::test::TestReEnterRegionOp::verify() {
756c580bd26SAlex Zinenko   if (getNumOperands() != getBody().front().getNumArguments()) {
757c580bd26SAlex Zinenko     return emitOpError() << "expects as many operands as block arguments";
758c580bd26SAlex Zinenko   }
759c580bd26SAlex Zinenko   return success();
760c580bd26SAlex Zinenko }
761c580bd26SAlex Zinenko 
7629cc8e458SMatthias Springer DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
7639cc8e458SMatthias Springer     transform::TransformRewriter &rewriter,
7649cc8e458SMatthias Springer     transform::TransformResults &results, transform::TransformState &state) {
7659cc8e458SMatthias Springer   auto originalOps = state.getPayloadOps(getOriginal());
7669cc8e458SMatthias Springer   auto replacementOps = state.getPayloadOps(getReplacement());
7679cc8e458SMatthias Springer   if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
7689cc8e458SMatthias Springer     return emitSilenceableError() << "expected same number of original and "
7699cc8e458SMatthias Springer                                      "replacement payload operations";
7709cc8e458SMatthias Springer   for (const auto &[original, replacement] :
7719cc8e458SMatthias Springer        llvm::zip(originalOps, replacementOps)) {
7729cc8e458SMatthias Springer     if (failed(
7739cc8e458SMatthias Springer             rewriter.notifyPayloadOperationReplaced(original, replacement))) {
7749cc8e458SMatthias Springer       auto diag = emitSilenceableError()
7759cc8e458SMatthias Springer                   << "unable to replace payload op in transform mapping";
7769cc8e458SMatthias Springer       diag.attachNote(original->getLoc()) << "original payload op";
7779cc8e458SMatthias Springer       diag.attachNote(replacement->getLoc()) << "replacement payload op";
7789cc8e458SMatthias Springer       return diag;
7799cc8e458SMatthias Springer     }
7809cc8e458SMatthias Springer   }
7819cc8e458SMatthias Springer   return DiagnosedSilenceableFailure::success();
7829cc8e458SMatthias Springer }
7839cc8e458SMatthias Springer 
7849cc8e458SMatthias Springer void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
7859cc8e458SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
7862c1ae801Sdonald chen   transform::onlyReadsHandle(getOriginalMutable(), effects);
7872c1ae801Sdonald chen   transform::onlyReadsHandle(getReplacementMutable(), effects);
7889cc8e458SMatthias Springer }
7899cc8e458SMatthias Springer 
7907dfcd4b7SMatthias Springer DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
7917dfcd4b7SMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
7927dfcd4b7SMatthias Springer     transform::ApplyToEachResultList &results,
7937dfcd4b7SMatthias Springer     transform::TransformState &state) {
7947dfcd4b7SMatthias Springer   // Provide some IR that does not verify.
7957dfcd4b7SMatthias Springer   rewriter.setInsertionPointToStart(&target->getRegion(0).front());
7967dfcd4b7SMatthias Springer   rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
7977dfcd4b7SMatthias Springer                                       ValueRange(), /*failToVerify=*/true);
7987dfcd4b7SMatthias Springer   return DiagnosedSilenceableFailure::success();
7997dfcd4b7SMatthias Springer }
8007dfcd4b7SMatthias Springer 
8017dfcd4b7SMatthias Springer void mlir::test::TestProduceInvalidIR::getEffects(
8027dfcd4b7SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
8032c1ae801Sdonald chen   transform::onlyReadsHandle(getTargetMutable(), effects);
8047dfcd4b7SMatthias Springer   transform::modifiesPayload(effects);
8057dfcd4b7SMatthias Springer }
8067dfcd4b7SMatthias Springer 
8076634d44eSAmy Wang DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
8086634d44eSAmy Wang     transform::TransformRewriter &rewriter,
8096634d44eSAmy Wang     transform::TransformResults &results, transform::TransformState &state) {
8106634d44eSAmy Wang   std::string opName =
8116634d44eSAmy Wang       this->getOperationName().str() + "_" + getTypeAttr().str();
8126634d44eSAmy Wang   TransformStateInitializerExtension *initExt =
8136634d44eSAmy Wang       state.getExtension<TransformStateInitializerExtension>();
8146634d44eSAmy Wang   if (!initExt) {
8156634d44eSAmy Wang     emitRemark() << "\nSpecified extension not found, adding a new one!\n";
8166634d44eSAmy Wang     SmallVector<std::string> opCollection = {opName};
8176634d44eSAmy Wang     state.addExtension<TransformStateInitializerExtension>(1, opCollection);
8186634d44eSAmy Wang   } else {
8196634d44eSAmy Wang     initExt->setNumOp(initExt->getNumOp() + 1);
8206634d44eSAmy Wang     initExt->pushRegisteredOps(opName);
8216634d44eSAmy Wang     InFlightDiagnostic diag = emitRemark()
8226634d44eSAmy Wang                               << "Number of currently registered op: "
8236634d44eSAmy Wang                               << initExt->getNumOp() << "\n"
8246634d44eSAmy Wang                               << initExt->printMessage() << "\n";
8256634d44eSAmy Wang   }
8266634d44eSAmy Wang   return DiagnosedSilenceableFailure::success();
8276634d44eSAmy Wang }
8286634d44eSAmy Wang 
8295a10f207SMatthias Springer namespace {
830bcfdb3e4SMatthias Springer /// Test conversion pattern that replaces ops with the "replace_with_new_op"
831bcfdb3e4SMatthias Springer /// attribute with "test.new_op".
832bcfdb3e4SMatthias Springer class ReplaceWithNewOpConversion : public ConversionPattern {
833bcfdb3e4SMatthias Springer public:
834bcfdb3e4SMatthias Springer   ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
835bcfdb3e4SMatthias Springer       : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
836bcfdb3e4SMatthias Springer                           /*benefit=*/1, context) {}
837bcfdb3e4SMatthias Springer 
838bcfdb3e4SMatthias Springer   LogicalResult
839bcfdb3e4SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
840bcfdb3e4SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
841bcfdb3e4SMatthias Springer     if (!op->hasAttr("replace_with_new_op"))
842bcfdb3e4SMatthias Springer       return failure();
843bcfdb3e4SMatthias Springer     SmallVector<Type> newResultTypes;
844bcfdb3e4SMatthias Springer     if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
845bcfdb3e4SMatthias Springer                                                 newResultTypes)))
846bcfdb3e4SMatthias Springer       return failure();
847bcfdb3e4SMatthias Springer     Operation *newOp = rewriter.create(
848bcfdb3e4SMatthias Springer         op->getLoc(),
849bcfdb3e4SMatthias Springer         OperationName("test.new_op", op->getContext()).getIdentifier(),
850bcfdb3e4SMatthias Springer         operands, newResultTypes);
851bcfdb3e4SMatthias Springer     rewriter.replaceOp(op, newOp->getResults());
852bcfdb3e4SMatthias Springer     return success();
853bcfdb3e4SMatthias Springer   }
854bcfdb3e4SMatthias Springer };
855bcfdb3e4SMatthias Springer } // namespace
856bcfdb3e4SMatthias Springer 
857bcfdb3e4SMatthias Springer void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
858bcfdb3e4SMatthias Springer     TypeConverter &typeConverter, RewritePatternSet &patterns) {
859bcfdb3e4SMatthias Springer   patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
860bcfdb3e4SMatthias Springer                                               patterns.getContext());
861bcfdb3e4SMatthias Springer }
862bcfdb3e4SMatthias Springer 
863bcfdb3e4SMatthias Springer namespace {
864bcfdb3e4SMatthias Springer /// Test type converter that converts tensor types to memref types.
865bcfdb3e4SMatthias Springer class TestTypeConverter : public TypeConverter {
866bcfdb3e4SMatthias Springer public:
867bcfdb3e4SMatthias Springer   TestTypeConverter() {
868e2d39f79SChristopher Bate     addConversion([](Type t) { return t; });
869bcfdb3e4SMatthias Springer     addConversion([](RankedTensorType type) -> Type {
870bcfdb3e4SMatthias Springer       return MemRefType::get(type.getShape(), type.getElementType());
871bcfdb3e4SMatthias Springer     });
872e2d39f79SChristopher Bate     auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
873bcfdb3e4SMatthias Springer                                        ValueRange inputs,
874*f18c3e4eSMatthias Springer                                        Location loc) -> Value {
875bcfdb3e4SMatthias Springer       if (inputs.size() != 1)
876*f18c3e4eSMatthias Springer         return Value();
877bcfdb3e4SMatthias Springer       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
878bcfdb3e4SMatthias Springer           .getResult(0);
879e2d39f79SChristopher Bate     };
880e2d39f79SChristopher Bate     addSourceMaterialization(unrealizedCastConverter);
881e2d39f79SChristopher Bate     addTargetMaterialization(unrealizedCastConverter);
882bcfdb3e4SMatthias Springer   }
883bcfdb3e4SMatthias Springer };
884bcfdb3e4SMatthias Springer } // namespace
885bcfdb3e4SMatthias Springer 
886bcfdb3e4SMatthias Springer std::unique_ptr<::mlir::TypeConverter>
887bcfdb3e4SMatthias Springer mlir::test::TestTypeConverterOp::getTypeConverter() {
888bcfdb3e4SMatthias Springer   return std::make_unique<TestTypeConverter>();
889bcfdb3e4SMatthias Springer }
890bcfdb3e4SMatthias Springer 
891bcfdb3e4SMatthias Springer namespace {
892d064c480SAlex Zinenko /// Test extension of the Transform dialect. Registers additional ops and
893d064c480SAlex Zinenko /// declares PDL as dependent dialect since the additional ops are using PDL
894d064c480SAlex Zinenko /// types for operands and results.
895d064c480SAlex Zinenko class TestTransformDialectExtension
896d064c480SAlex Zinenko     : public transform::TransformDialectExtension<
897d064c480SAlex Zinenko           TestTransformDialectExtension> {
898d064c480SAlex Zinenko public:
89984cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)
90084cc1865SNikhil Kalra 
901333ee218SAlex Zinenko   using Base::Base;
902333ee218SAlex Zinenko 
903333ee218SAlex Zinenko   void init() {
904d064c480SAlex Zinenko     declareDependentDialect<pdl::PDLDialect>();
905d064c480SAlex Zinenko     registerTransformOps<TestTransformOp,
90630f22429SAlex Zinenko                          TestTransformUnrestrictedOpNoInterface,
907d064c480SAlex Zinenko #define GET_OP_LIST
908d064c480SAlex Zinenko #include "TestTransformDialectExtension.cpp.inc"
909d064c480SAlex Zinenko                          >();
910bba85ebdSAlex Zinenko     registerTypes<
911bba85ebdSAlex Zinenko #define GET_TYPEDEF_LIST
912bba85ebdSAlex Zinenko #include "TestTransformDialectExtensionTypes.cpp.inc"
913bba85ebdSAlex Zinenko         >();
91494d608d4SAlex Zinenko 
9158ec28af8SMatthias Gehre     auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
91694d608d4SAlex Zinenko                                 ArrayRef<PDLValue> pdlValues) {
91794d608d4SAlex Zinenko       for (const PDLValue &pdlValue : pdlValues) {
91894d608d4SAlex Zinenko         if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
91994d608d4SAlex Zinenko           op->emitWarning() << "from PDL constraint";
92094d608d4SAlex Zinenko         }
92194d608d4SAlex Zinenko       }
92294d608d4SAlex Zinenko       return success();
92394d608d4SAlex Zinenko     };
92494d608d4SAlex Zinenko 
92594d608d4SAlex Zinenko     addDialectDataInitializer<transform::PDLMatchHooks>(
92694d608d4SAlex Zinenko         [&](transform::PDLMatchHooks &hooks) {
92794d608d4SAlex Zinenko           llvm::StringMap<PDLConstraintFunction> constraints;
92894d608d4SAlex Zinenko           constraints.try_emplace("verbose_constraint", verboseConstraint);
92994d608d4SAlex Zinenko           hooks.mergeInPDLMatchHooks(std::move(constraints));
93094d608d4SAlex Zinenko         });
931d064c480SAlex Zinenko   }
932d064c480SAlex Zinenko };
933d064c480SAlex Zinenko } // namespace
934d064c480SAlex Zinenko 
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are forward-declared with LLVM_ATTRIBUTE_UNUSED so that the static
// definitions (presumably emitted by the generated include below) do not
// trigger unused-function warnings.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// Generated type class definitions of this test extension.
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// Generated op class definitions of this test extension.
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
947ed02fa81SAlex Zinenko 
/// Adds the test Transform dialect extension (its ops, types and PDL
/// constraint hooks) to the given dialect registry.
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<TestTransformDialectExtension>();
}
951