xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (revision f18c3e4e7335df282c468b6dff3d29be1822a96d)
1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "TestTransformStateExtension.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformOps.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Compiler.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace mlir;
29 
30 namespace {
31 /// Simple transform op defined outside of the dialect. Just emits a remark when
32 /// applied. This op is defined in C++ to test that C++ definitions also work
33 /// for op injection into the Transform dialect.
34 class TestTransformOp
35     : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
36                 MemoryEffectOpInterface::Trait> {
37 public:
38   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
39 
40   using Op::Op;
41 
42   static ArrayRef<StringRef> getAttributeNames() { return {}; }
43 
44   static constexpr llvm::StringLiteral getOperationName() {
45     return llvm::StringLiteral("transform.test_transform_op");
46   }
47 
48   DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
49                                     transform::TransformResults &results,
50                                     transform::TransformState &state) {
51     InFlightDiagnostic remark = emitRemark() << "applying transformation";
52     if (Attribute message = getMessage())
53       remark << " " << message;
54 
55     return DiagnosedSilenceableFailure::success();
56   }
57 
58   Attribute getMessage() {
59     return getOperation()->getDiscardableAttr("message");
60   }
61 
62   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
63     StringAttr message;
64     OptionalParseResult result = parser.parseOptionalAttribute(message);
65     if (!result.has_value())
66       return success();
67 
68     if (result.value().succeeded())
69       state.addAttribute("message", message);
70     return result.value();
71   }
72 
73   void print(OpAsmPrinter &printer) {
74     if (getMessage())
75       printer << " " << getMessage();
76   }
77 
78   // No side effects.
79   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
80 };
81 
82 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
83 /// in cases where it is attached to ops that do not comply with the trait
84 /// requirements. This op cannot be defined in ODS because ODS generates strict
85 /// verifiers that overalp with those in the trait and run earlier.
86 class TestTransformUnrestrictedOpNoInterface
87     : public Op<TestTransformUnrestrictedOpNoInterface,
88                 transform::PossibleTopLevelTransformOpTrait,
89                 transform::TransformOpInterface::Trait,
90                 MemoryEffectOpInterface::Trait> {
91 public:
92   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
93       TestTransformUnrestrictedOpNoInterface)
94 
95   using Op::Op;
96 
97   static ArrayRef<StringRef> getAttributeNames() { return {}; }
98 
99   static constexpr llvm::StringLiteral getOperationName() {
100     return llvm::StringLiteral(
101         "transform.test_transform_unrestricted_op_no_interface");
102   }
103 
104   DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
105                                     transform::TransformResults &results,
106                                     transform::TransformState &state) {
107     return DiagnosedSilenceableFailure::success();
108   }
109 
110   // No side effects.
111   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
112 };
113 } // namespace
114 
115 DiagnosedSilenceableFailure
116 mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
117     transform::TransformRewriter &rewriter,
118     transform::TransformResults &results, transform::TransformState &state) {
119   if (getOperation()->getNumOperands() != 0) {
120     results.set(cast<OpResult>(getResult()),
121                 {getOperation()->getOperand(0).getDefiningOp()});
122   } else {
123     results.set(cast<OpResult>(getResult()), {getOperation()});
124   }
125   return DiagnosedSilenceableFailure::success();
126 }
127 
128 void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
129     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
130   if (getOperand())
131     transform::onlyReadsHandle(getOperandMutable(), effects);
132   transform::producesHandle(getOperation()->getOpResults(), effects);
133 }
134 
135 DiagnosedSilenceableFailure
136 mlir::test::TestProduceValueHandleToSelfOperand::apply(
137     transform::TransformRewriter &rewriter,
138     transform::TransformResults &results, transform::TransformState &state) {
139   results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
140   return DiagnosedSilenceableFailure::success();
141 }
142 
143 void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
144     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145   transform::onlyReadsHandle(getInMutable(), effects);
146   transform::producesHandle(getOperation()->getOpResults(), effects);
147   transform::onlyReadsPayload(effects);
148 }
149 
150 DiagnosedSilenceableFailure
151 mlir::test::TestProduceValueHandleToResult::applyToOne(
152     transform::TransformRewriter &rewriter, Operation *target,
153     transform::ApplyToEachResultList &results,
154     transform::TransformState &state) {
155   if (target->getNumResults() <= getNumber())
156     return emitSilenceableError() << "payload has no result #" << getNumber();
157   results.push_back(target->getResult(getNumber()));
158   return DiagnosedSilenceableFailure::success();
159 }
160 
161 void mlir::test::TestProduceValueHandleToResult::getEffects(
162     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
163   transform::onlyReadsHandle(getInMutable(), effects);
164   transform::producesHandle(getOperation()->getOpResults(), effects);
165   transform::onlyReadsPayload(effects);
166 }
167 
168 DiagnosedSilenceableFailure
169 mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
170     transform::TransformRewriter &rewriter, Operation *target,
171     transform::ApplyToEachResultList &results,
172     transform::TransformState &state) {
173   if (!target->getBlock())
174     return emitSilenceableError() << "payload has no parent block";
175   if (target->getBlock()->getNumArguments() <= getNumber())
176     return emitSilenceableError()
177            << "parent of the payload has no argument #" << getNumber();
178   results.push_back(target->getBlock()->getArgument(getNumber()));
179   return DiagnosedSilenceableFailure::success();
180 }
181 
182 void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
183     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184   transform::onlyReadsHandle(getInMutable(), effects);
185   transform::producesHandle(getOperation()->getOpResults(), effects);
186   transform::onlyReadsPayload(effects);
187 }
188 
189 bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
190   return getAllowRepeatedHandles();
191 }
192 
193 DiagnosedSilenceableFailure
194 mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
195                                       transform::TransformResults &results,
196                                       transform::TransformState &state) {
197   return DiagnosedSilenceableFailure::success();
198 }
199 
200 void mlir::test::TestConsumeOperand::getEffects(
201     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
202   transform::consumesHandle(getOperation()->getOpOperands(), effects);
203   if (getSecondOperand())
204     transform::consumesHandle(getSecondOperandMutable(), effects);
205   transform::modifiesPayload(effects);
206 }
207 
208 DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
209     transform::TransformRewriter &rewriter,
210     transform::TransformResults &results, transform::TransformState &state) {
211   auto payload = state.getPayloadOps(getOperand());
212   assert(llvm::hasSingleElement(payload) && "expected a single target op");
213   if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
214     return emitSilenceableError()
215            << "op expected the operand to be associated a payload op of kind "
216            << getOpKind() << " got "
217            << (*payload.begin())->getName().getStringRef();
218   }
219 
220   emitRemark() << "succeeded";
221   return DiagnosedSilenceableFailure::success();
222 }
223 
224 void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
225     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
226   transform::consumesHandle(getOperation()->getOpOperands(), effects);
227   transform::modifiesPayload(effects);
228 }
229 
230 DiagnosedSilenceableFailure
231 mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
232     Operation *op, transform::TransformResults &results,
233     transform::TransformState &state) {
234   if (op->getName().getStringRef() != getOpKind()) {
235     return emitSilenceableError()
236            << "op expected the operand to be associated with a payload op of "
237               "kind "
238            << getOpKind() << " got " << op->getName().getStringRef();
239   }
240   return DiagnosedSilenceableFailure::success();
241 }
242 
243 void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
244     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
245   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
246   transform::onlyReadsPayload(effects);
247 }
248 
249 DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
250     transform::TransformRewriter &rewriter,
251     transform::TransformResults &results, transform::TransformState &state) {
252   state.addExtension<TestTransformStateExtension>(getMessageAttr());
253   return DiagnosedSilenceableFailure::success();
254 }
255 
256 DiagnosedSilenceableFailure
257 mlir::test::TestCheckIfTestExtensionPresentOp::apply(
258     transform::TransformRewriter &rewriter,
259     transform::TransformResults &results, transform::TransformState &state) {
260   auto *extension = state.getExtension<TestTransformStateExtension>();
261   if (!extension) {
262     emitRemark() << "extension absent";
263     return DiagnosedSilenceableFailure::success();
264   }
265 
266   InFlightDiagnostic diag = emitRemark()
267                             << "extension present, " << extension->getMessage();
268   for (Operation *payload : state.getPayloadOps(getOperand())) {
269     diag.attachNote(payload->getLoc()) << "associated payload op";
270 #ifndef NDEBUG
271     SmallVector<Value> handles;
272     assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
273     assert(llvm::is_contained(handles, getOperand()) &&
274            "inconsistent mapping between transform IR handles and payload IR "
275            "operations");
276 #endif // NDEBUG
277   }
278 
279   return DiagnosedSilenceableFailure::success();
280 }
281 
282 void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
283     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
284   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
285   transform::onlyReadsPayload(effects);
286 }
287 
288 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
289     transform::TransformRewriter &rewriter,
290     transform::TransformResults &results, transform::TransformState &state) {
291   auto *extension = state.getExtension<TestTransformStateExtension>();
292   if (!extension)
293     return emitDefiniteFailure("TestTransformStateExtension missing");
294 
295   if (failed(extension->updateMapping(
296           *state.getPayloadOps(getOperand()).begin(), getOperation())))
297     return DiagnosedSilenceableFailure::definiteFailure();
298   if (getNumResults() > 0)
299     results.set(cast<OpResult>(getResult(0)), {getOperation()});
300   return DiagnosedSilenceableFailure::success();
301 }
302 
303 void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
304     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
305   transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
306   transform::producesHandle(getOperation()->getOpResults(), effects);
307   transform::onlyReadsPayload(effects);
308 }
309 
310 DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
311     transform::TransformRewriter &rewriter,
312     transform::TransformResults &results, transform::TransformState &state) {
313   state.removeExtension<TestTransformStateExtension>();
314   return DiagnosedSilenceableFailure::success();
315 }
316 
317 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
318     transform::TransformRewriter &rewriter,
319     transform::TransformResults &results, transform::TransformState &state) {
320   auto payloadOps = state.getPayloadOps(getTarget());
321   auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
322   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
323   return DiagnosedSilenceableFailure::success();
324 }
325 
326 DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
327     transform::TransformRewriter &rewriter,
328     transform::TransformResults &results, transform::TransformState &state) {
329   return DiagnosedSilenceableFailure::success();
330 }
331 
332 void mlir::test::TestTransformOpWithRegions::getEffects(
333     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
334 
335 DiagnosedSilenceableFailure
336 mlir::test::TestBranchingTransformOpTerminator::apply(
337     transform::TransformRewriter &rewriter,
338     transform::TransformResults &results, transform::TransformState &state) {
339   return DiagnosedSilenceableFailure::success();
340 }
341 
342 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
343     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
344 
345 DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
346     transform::TransformRewriter &rewriter,
347     transform::TransformResults &results, transform::TransformState &state) {
348   emitRemark() << getRemark();
349   for (Operation *op : state.getPayloadOps(getTarget()))
350     rewriter.eraseOp(op);
351 
352   if (getFailAfterErase())
353     return emitSilenceableError() << "silenceable error";
354   return DiagnosedSilenceableFailure::success();
355 }
356 
357 void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
358     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
359   transform::consumesHandle(getTargetMutable(), effects);
360   transform::modifiesPayload(effects);
361 }
362 
363 DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
364     transform::TransformRewriter &rewriter, Operation *target,
365     transform::ApplyToEachResultList &results,
366     transform::TransformState &state) {
367   OperationState opState(target->getLoc(), "foo");
368   results.push_back(OpBuilder(target).create(opState));
369   return DiagnosedSilenceableFailure::success();
370 }
371 
372 DiagnosedSilenceableFailure
373 mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
374     transform::TransformRewriter &rewriter, Operation *target,
375     transform::ApplyToEachResultList &results,
376     transform::TransformState &state) {
377   static int count = 0;
378   if (count++ == 0) {
379     OperationState opState(target->getLoc(), "foo");
380     results.push_back(OpBuilder(target).create(opState));
381   }
382   return DiagnosedSilenceableFailure::success();
383 }
384 
385 DiagnosedSilenceableFailure
386 mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
387     transform::TransformRewriter &rewriter, Operation *target,
388     transform::ApplyToEachResultList &results,
389     transform::TransformState &state) {
390   OperationState opState(target->getLoc(), "foo");
391   results.push_back(OpBuilder(target).create(opState));
392   results.push_back(OpBuilder(target).create(opState));
393   return DiagnosedSilenceableFailure::success();
394 }
395 
396 DiagnosedSilenceableFailure
397 mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
398     transform::TransformRewriter &rewriter, Operation *target,
399     transform::ApplyToEachResultList &results,
400     transform::TransformState &state) {
401   OperationState opState(target->getLoc(), "foo");
402   results.push_back(nullptr);
403   results.push_back(OpBuilder(target).create(opState));
404   return DiagnosedSilenceableFailure::success();
405 }
406 
407 DiagnosedSilenceableFailure
408 mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
409     transform::TransformRewriter &rewriter, Operation *target,
410     transform::ApplyToEachResultList &results,
411     transform::TransformState &state) {
412   if (target->hasAttr("target_me"))
413     return DiagnosedSilenceableFailure::success();
414   return emitDefaultSilenceableFailure(target);
415 }
416 
417 DiagnosedSilenceableFailure
418 mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
419                                      transform::TransformResults &results,
420                                      transform::TransformState &state) {
421   results.set(llvm::cast<OpResult>(getCopy()),
422               state.getPayloadOps(getHandle()));
423   return DiagnosedSilenceableFailure::success();
424 }
425 
426 void mlir::test::TestCopyPayloadOp::getEffects(
427     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
428   transform::onlyReadsHandle(getHandleMutable(), effects);
429   transform::producesHandle(getOperation()->getOpResults(), effects);
430   transform::onlyReadsPayload(effects);
431 }
432 
433 DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
434     Location loc, ArrayRef<Operation *> payload) const {
435   if (payload.empty())
436     return DiagnosedSilenceableFailure::success();
437 
438   for (Operation *op : payload) {
439     if (op->getName().getDialectNamespace() != "test") {
440       return emitSilenceableError(loc) << "expected the payload operation to "
441                                           "belong to the 'test' dialect";
442     }
443   }
444 
445   return DiagnosedSilenceableFailure::success();
446 }
447 
448 DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
449     Location loc, ArrayRef<Attribute> payload) const {
450   for (Attribute attr : payload) {
451     auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
452     if (integerAttr && integerAttr.getType().isSignlessInteger(32))
453       continue;
454     return emitSilenceableError(loc)
455            << "expected the parameter to be a i32 integer attribute";
456   }
457 
458   return DiagnosedSilenceableFailure::success();
459 }
460 
461 void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
462     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
463   transform::onlyReadsHandle(getTargetMutable(), effects);
464 }
465 
466 DiagnosedSilenceableFailure
467 mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
468     transform::TransformRewriter &rewriter,
469     transform::TransformResults &results, transform::TransformState &state) {
470   int64_t count = 0;
471   for (Operation *op : state.getPayloadOps(getTarget())) {
472     op->walk([&](Operation *nested) {
473       SmallVector<Value> handles;
474       (void)state.getHandlesForPayloadOp(nested, handles);
475       count += handles.size();
476     });
477   }
478   emitRemark() << count << " handles nested under";
479   return DiagnosedSilenceableFailure::success();
480 }
481 
482 DiagnosedSilenceableFailure
483 mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
484                                     transform::TransformResults &results,
485                                     transform::TransformState &state) {
486   SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
487   if (Value param = getParam()) {
488     values = llvm::to_vector(
489         llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
490           return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
491               UINT32_MAX);
492         }));
493   }
494 
495   Builder builder(getContext());
496   SmallVector<Attribute> result = llvm::to_vector(
497       llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
498         return builder.getI32IntegerAttr(value + getAddendum());
499       }));
500   results.setParams(llvm::cast<OpResult>(getResult()), result);
501   return DiagnosedSilenceableFailure::success();
502 }
503 
504 DiagnosedSilenceableFailure
505 mlir::test::TestProduceParamWithNumberOfTestOps::apply(
506     transform::TransformRewriter &rewriter,
507     transform::TransformResults &results, transform::TransformState &state) {
508   Builder builder(getContext());
509   SmallVector<Attribute> result = llvm::to_vector(
510       llvm::map_range(state.getPayloadOps(getHandle()),
511                       [&builder](Operation *payload) -> Attribute {
512                         int32_t count = 0;
513                         payload->walk([&count](Operation *op) {
514                           if (op->getName().getDialectNamespace() == "test")
515                             ++count;
516                         });
517                         return builder.getI32IntegerAttr(count);
518                       }));
519   results.setParams(llvm::cast<OpResult>(getResult()), result);
520   return DiagnosedSilenceableFailure::success();
521 }
522 
523 DiagnosedSilenceableFailure
524 mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
525                                       transform::TransformResults &results,
526                                       transform::TransformState &state) {
527   results.setParams(llvm::cast<OpResult>(getResult()), getAttr());
528   return DiagnosedSilenceableFailure::success();
529 }
530 
531 void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
532     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
533   transform::onlyReadsHandle(getInMutable(), effects);
534   transform::producesHandle(getOperation()->getOpResults(), effects);
535 }
536 
537 DiagnosedSilenceableFailure
538 mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
539     transform::TransformRewriter &rewriter, Operation *target,
540     ::transform::ApplyToEachResultList &results,
541     ::transform::TransformState &state) {
542   Builder builder(getContext());
543   if (getFirstResultIsParam()) {
544     results.push_back(builder.getI64IntegerAttr(0));
545   } else if (getFirstResultIsNull()) {
546     results.push_back(nullptr);
547   } else {
548     results.push_back(*state.getPayloadOps(getIn()).begin());
549   }
550 
551   if (getSecondResultIsHandle()) {
552     results.push_back(*state.getPayloadOps(getIn()).begin());
553   } else {
554     results.push_back(builder.getI64IntegerAttr(42));
555   }
556 
557   return DiagnosedSilenceableFailure::success();
558 }
559 
560 void mlir::test::TestProduceNullPayloadOp::getEffects(
561     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
562   transform::producesHandle(getOperation()->getOpResults(), effects);
563 }
564 
565 DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
566     transform::TransformRewriter &rewriter,
567     transform::TransformResults &results, transform::TransformState &state) {
568   SmallVector<Operation *, 1> null({nullptr});
569   results.set(llvm::cast<OpResult>(getOut()), null);
570   return DiagnosedSilenceableFailure::success();
571 }
572 
573 DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
574     transform::TransformRewriter &rewriter,
575     transform::TransformResults &results, transform::TransformState &state) {
576   results.set(cast<OpResult>(getOut()), {});
577   return DiagnosedSilenceableFailure::success();
578 }
579 
580 void mlir::test::TestProduceNullParamOp::getEffects(
581     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
582   transform::producesHandle(getOperation()->getOpResults(), effects);
583 }
584 
585 DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
586     transform::TransformRewriter &rewriter,
587     transform::TransformResults &results, transform::TransformState &state) {
588   results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
589   return DiagnosedSilenceableFailure::success();
590 }
591 
592 void mlir::test::TestProduceNullValueOp::getEffects(
593     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
594   transform::producesHandle(getOperation()->getOpResults(), effects);
595 }
596 
597 DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
598     transform::TransformRewriter &rewriter,
599     transform::TransformResults &results, transform::TransformState &state) {
600   results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
601   return DiagnosedSilenceableFailure::success();
602 }
603 
604 void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
605     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
606   if (getHasOperandEffect())
607     transform::consumesHandle(getInMutable(), effects);
608 
609   if (getHasResultEffect()) {
610     transform::producesHandle(getOperation()->getOpResults(), effects);
611   } else {
612     effects.emplace_back(MemoryEffects::Read::get(),
613                          llvm::cast<OpResult>(getOut()),
614                          transform::TransformMappingResource::get());
615   }
616 
617   if (getModifiesPayload())
618     transform::modifiesPayload(effects);
619 }
620 
621 DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
622     transform::TransformRewriter &rewriter,
623     transform::TransformResults &results, transform::TransformState &state) {
624   results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
625   return DiagnosedSilenceableFailure::success();
626 }
627 
628 void mlir::test::TestTrackedRewriteOp::getEffects(
629     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
630   transform::onlyReadsHandle(getInMutable(), effects);
631   transform::modifiesPayload(effects);
632 }
633 
634 void mlir::test::TestDummyPayloadOp::getEffects(
635     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
636   transform::producesHandle(getOperation()->getOpResults(), effects);
637 }
638 
639 LogicalResult mlir::test::TestDummyPayloadOp::verify() {
640   if (getFailToVerify())
641     return emitOpError() << "fail_to_verify is set";
642   return success();
643 }
644 
645 DiagnosedSilenceableFailure
646 mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
647                                         transform::TransformResults &results,
648                                         transform::TransformState &state) {
649   int64_t numIterations = 0;
650 
651   // `getPayloadOps` returns an iterator that skips ops that are erased in the
652   // loop body. Replacement ops are not enumerated.
653   for (Operation *op : state.getPayloadOps(getIn())) {
654     ++numIterations;
655     (void)op;
656 
657     // Erase all payload ops. The outer loop should have only one iteration.
658     for (Operation *op : state.getPayloadOps(getIn())) {
659       rewriter.setInsertionPoint(op);
660       if (op->hasAttr("erase_me")) {
661         rewriter.eraseOp(op);
662         continue;
663       }
664       if (!op->hasAttr("replace_me")) {
665         continue;
666       }
667 
668       SmallVector<NamedAttribute> attributes;
669       attributes.emplace_back(rewriter.getStringAttr("new_op"),
670                               rewriter.getUnitAttr());
671       OperationState opState(op->getLoc(), op->getName().getIdentifier(),
672                              /*operands=*/ValueRange(),
673                              /*types=*/op->getResultTypes(), attributes);
674       Operation *newOp = rewriter.create(opState);
675       rewriter.replaceOp(op, newOp->getResults());
676     }
677   }
678 
679   emitRemark() << numIterations << " iterations";
680   return DiagnosedSilenceableFailure::success();
681 }
682 
683 namespace {
684 // Test pattern to replace an operation with a new op.
685 class ReplaceWithNewOp : public RewritePattern {
686 public:
687   ReplaceWithNewOp(MLIRContext *context)
688       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
689 
690   LogicalResult matchAndRewrite(Operation *op,
691                                 PatternRewriter &rewriter) const override {
692     auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
693     if (!newName)
694       return failure();
695     Operation *newOp = rewriter.create(
696         op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
697         op->getOperands(), op->getResultTypes());
698     rewriter.replaceOp(op, newOp->getResults());
699     return success();
700   }
701 };
702 
703 // Test pattern to erase an operation.
704 class EraseOp : public RewritePattern {
705 public:
706   EraseOp(MLIRContext *context)
707       : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
708   LogicalResult matchAndRewrite(Operation *op,
709                                 PatternRewriter &rewriter) const override {
710     rewriter.eraseOp(op);
711     return success();
712   }
713 };
714 } // namespace
715 
716 void mlir::test::ApplyTestPatternsOp::populatePatterns(
717     RewritePatternSet &patterns) {
718   patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
719 }
720 
721 void mlir::test::TestReEnterRegionOp::getEffects(
722     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
723   transform::consumesHandle(getOperation()->getOpOperands(), effects);
724   transform::modifiesPayload(effects);
725 }
726 
727 DiagnosedSilenceableFailure
728 mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
729                                        transform::TransformResults &results,
730                                        transform::TransformState &state) {
731 
732   SmallVector<SmallVector<transform::MappedValue>> mappings;
733   for (BlockArgument arg : getBody().front().getArguments()) {
734     mappings.emplace_back(llvm::to_vector(llvm::map_range(
735         state.getPayloadOps(getOperand(arg.getArgNumber())),
736         [](Operation *op) -> transform::MappedValue { return op; })));
737   }
738 
739   for (int i = 0; i < 4; ++i) {
740     auto scope = state.make_region_scope(getBody());
741     for (BlockArgument arg : getBody().front().getArguments()) {
742       if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
743         return DiagnosedSilenceableFailure::definiteFailure();
744     }
745     for (Operation &op : getBody().front().without_terminator()) {
746       DiagnosedSilenceableFailure diag =
747           state.applyTransform(cast<transform::TransformOpInterface>(op));
748       if (!diag.succeeded())
749         return diag;
750     }
751   }
752   return DiagnosedSilenceableFailure::success();
753 }
754 
755 LogicalResult mlir::test::TestReEnterRegionOp::verify() {
756   if (getNumOperands() != getBody().front().getNumArguments()) {
757     return emitOpError() << "expects as many operands as block arguments";
758   }
759   return success();
760 }
761 
762 DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
763     transform::TransformRewriter &rewriter,
764     transform::TransformResults &results, transform::TransformState &state) {
765   auto originalOps = state.getPayloadOps(getOriginal());
766   auto replacementOps = state.getPayloadOps(getReplacement());
767   if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
768     return emitSilenceableError() << "expected same number of original and "
769                                      "replacement payload operations";
770   for (const auto &[original, replacement] :
771        llvm::zip(originalOps, replacementOps)) {
772     if (failed(
773             rewriter.notifyPayloadOperationReplaced(original, replacement))) {
774       auto diag = emitSilenceableError()
775                   << "unable to replace payload op in transform mapping";
776       diag.attachNote(original->getLoc()) << "original payload op";
777       diag.attachNote(replacement->getLoc()) << "replacement payload op";
778       return diag;
779     }
780   }
781   return DiagnosedSilenceableFailure::success();
782 }
783 
784 void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
785     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
786   transform::onlyReadsHandle(getOriginalMutable(), effects);
787   transform::onlyReadsHandle(getReplacementMutable(), effects);
788 }
789 
790 DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
791     transform::TransformRewriter &rewriter, Operation *target,
792     transform::ApplyToEachResultList &results,
793     transform::TransformState &state) {
794   // Provide some IR that does not verify.
795   rewriter.setInsertionPointToStart(&target->getRegion(0).front());
796   rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
797                                       ValueRange(), /*failToVerify=*/true);
798   return DiagnosedSilenceableFailure::success();
799 }
800 
801 void mlir::test::TestProduceInvalidIR::getEffects(
802     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
803   transform::onlyReadsHandle(getTargetMutable(), effects);
804   transform::modifiesPayload(effects);
805 }
806 
807 DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
808     transform::TransformRewriter &rewriter,
809     transform::TransformResults &results, transform::TransformState &state) {
810   std::string opName =
811       this->getOperationName().str() + "_" + getTypeAttr().str();
812   TransformStateInitializerExtension *initExt =
813       state.getExtension<TransformStateInitializerExtension>();
814   if (!initExt) {
815     emitRemark() << "\nSpecified extension not found, adding a new one!\n";
816     SmallVector<std::string> opCollection = {opName};
817     state.addExtension<TransformStateInitializerExtension>(1, opCollection);
818   } else {
819     initExt->setNumOp(initExt->getNumOp() + 1);
820     initExt->pushRegisteredOps(opName);
821     InFlightDiagnostic diag = emitRemark()
822                               << "Number of currently registered op: "
823                               << initExt->getNumOp() << "\n"
824                               << initExt->printMessage() << "\n";
825   }
826   return DiagnosedSilenceableFailure::success();
827 }
828 
829 namespace {
830 /// Test conversion pattern that replaces ops with the "replace_with_new_op"
831 /// attribute with "test.new_op".
832 class ReplaceWithNewOpConversion : public ConversionPattern {
833 public:
834   ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
835       : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
836                           /*benefit=*/1, context) {}
837 
838   LogicalResult
839   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
840                   ConversionPatternRewriter &rewriter) const override {
841     if (!op->hasAttr("replace_with_new_op"))
842       return failure();
843     SmallVector<Type> newResultTypes;
844     if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
845                                                 newResultTypes)))
846       return failure();
847     Operation *newOp = rewriter.create(
848         op->getLoc(),
849         OperationName("test.new_op", op->getContext()).getIdentifier(),
850         operands, newResultTypes);
851     rewriter.replaceOp(op, newOp->getResults());
852     return success();
853   }
854 };
855 } // namespace
856 
857 void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
858     TypeConverter &typeConverter, RewritePatternSet &patterns) {
859   patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
860                                               patterns.getContext());
861 }
862 
863 namespace {
864 /// Test type converter that converts tensor types to memref types.
865 class TestTypeConverter : public TypeConverter {
866 public:
867   TestTypeConverter() {
868     addConversion([](Type t) { return t; });
869     addConversion([](RankedTensorType type) -> Type {
870       return MemRefType::get(type.getShape(), type.getElementType());
871     });
872     auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
873                                        ValueRange inputs,
874                                        Location loc) -> Value {
875       if (inputs.size() != 1)
876         return Value();
877       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
878           .getResult(0);
879     };
880     addSourceMaterialization(unrealizedCastConverter);
881     addTargetMaterialization(unrealizedCastConverter);
882   }
883 };
884 } // namespace
885 
886 std::unique_ptr<::mlir::TypeConverter>
887 mlir::test::TestTypeConverterOp::getTypeConverter() {
888   return std::make_unique<TestTypeConverter>();
889 }
890 
891 namespace {
892 /// Test extension of the Transform dialect. Registers additional ops and
893 /// declares PDL as dependent dialect since the additional ops are using PDL
894 /// types for operands and results.
895 class TestTransformDialectExtension
896     : public transform::TransformDialectExtension<
897           TestTransformDialectExtension> {
898 public:
899   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)
900 
901   using Base::Base;
902 
903   void init() {
904     declareDependentDialect<pdl::PDLDialect>();
905     registerTransformOps<TestTransformOp,
906                          TestTransformUnrestrictedOpNoInterface,
907 #define GET_OP_LIST
908 #include "TestTransformDialectExtension.cpp.inc"
909                          >();
910     registerTypes<
911 #define GET_TYPEDEF_LIST
912 #include "TestTransformDialectExtensionTypes.cpp.inc"
913         >();
914 
915     auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
916                                 ArrayRef<PDLValue> pdlValues) {
917       for (const PDLValue &pdlValue : pdlValues) {
918         if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
919           op->emitWarning() << "from PDL constraint";
920         }
921       }
922       return success();
923     };
924 
925     addDialectDataInitializer<transform::PDLMatchHooks>(
926         [&](transform::PDLMatchHooks &hooks) {
927           llvm::StringMap<PDLConstraintFunction> constraints;
928           constraints.try_emplace("verbose_constraint", verboseConstraint);
929           hooks.mergeInPDLMatchHooks(std::move(constraints));
930         });
931   }
932 };
933 } // namespace
934 
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are forward-declared with LLVM_ATTRIBUTE_UNUSED, presumably so the
// compiler does not warn about the unused static definitions emitted by the
// generated type .inc below.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// ODS-generated definitions of the test extension's type classes.
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// ODS-generated definitions of the test extension's op classes.
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
947 
/// Adds the test Transform dialect extension to `registry`.
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<TestTransformDialectExtension>();
}
951