xref: /llvm-project/mlir/examples/transform/Ch3/lib/MyExtension.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines Transform dialect extension operations used in the
10 // Chapter 3 of the Transform dialect tutorial.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "MyExtension.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/Interfaces/CallInterfaces.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define GET_TYPEDEF_CLASSES
24 #include "MyExtensionTypes.cpp.inc"
25 
26 #define GET_OP_CLASSES
27 #include "MyExtension.cpp.inc"
28 
29 //===---------------------------------------------------------------------===//
30 // MyExtension
31 //===---------------------------------------------------------------------===//
32 
// Define a new transform dialect extension. This uses the CRTP idiom to
// identify extensions: the extension class passes itself as the template
// argument of its base.
class MyExtension
    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The TypeID of this extension.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};
49 
// Populates the extension: declares the dialects its transformations may
// generate, then registers the ODS-generated transform types and operations
// with the Transform dialect.
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Register the additional transform dialect types with the dialect. List all
  // types generated from ODS. The list is produced by re-including the
  // generated file with GET_TYPEDEF_LIST defined.
  registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
      >();

  // ODS generates these helpers for type printing and parsing, but the
  // Transform dialect provides its own support for types supplied by the
  // extension. Reference these functions to avoid a compiler warning.
  (void)&generatedTypeParser;
  (void)&generatedTypePrinter;

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS (via GET_OP_LIST). This call will
  // perform additional checks that the operations implement the transform and
  // memory effect interfaces required by the dialect interpreter and assert if
  // they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
95 
96 //===---------------------------------------------------------------------===//
97 // ChangeCallTargetOp
98 //===---------------------------------------------------------------------===//
99 
100 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
101   call.setCallee(newTarget);
102 }
103 
104 // Implementation of our transform dialect operation.
105 // This operation returns a tri-state result that can be one of:
106 // - success when the transformation succeeded;
107 // - definite failure when the transformation failed in such a way that
108 // following
109 //   transformations are impossible or undesirable, typically it could have left
110 //   payload IR in an invalid state; it is expected that a diagnostic is emitted
111 //   immediately before returning the definite error;
112 // - silenceable failure when the transformation failed but following
113 // transformations
114 //   are still applicable, typically this means a precondition for the
115 //   transformation is not satisfied and the payload IR has not been modified.
116 // The silenceable failure additionally carries a Diagnostic that can be emitted
117 // to the user.
118 ::mlir::DiagnosedSilenceableFailure
119 mlir::transform::ChangeCallTargetOp::applyToOne(
120     // The rewriter that should be used when modifying IR.
121     ::mlir::transform::TransformRewriter &rewriter,
122     // The single payload operation to which the transformation is applied.
123     ::mlir::func::CallOp call,
124     // The payload IR entities that will be appended to lists associated with
125     // the results of this transform operation. This list contains one entry per
126     // result.
127     ::mlir::transform::ApplyToEachResultList &results,
128     // The transform application state. This object can be used to query the
129     // current associations between transform IR values and payload IR entities.
130     // It can also carry additional user-defined state.
131     ::mlir::transform::TransformState &state) {
132 
133   // Dispatch to the actual transformation.
134   updateCallee(call, getNewTarget());
135 
136   // If everything went well, return success.
137   return DiagnosedSilenceableFailure::success();
138 }
139 
140 void mlir::transform::ChangeCallTargetOp::getEffects(
141     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
142   // Indicate that the `call` handle is only read by this operation because the
143   // associated operation is not erased but rather modified in-place, so the
144   // reference to it remains valid.
145   onlyReadsHandle(getCallMutable(), effects);
146 
147   // Indicate that the payload is modified by this operation.
148   modifiesPayload(effects);
149 }
150 
151 //===---------------------------------------------------------------------===//
152 // CallToOp
153 //===---------------------------------------------------------------------===//
154 
155 static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter,
156                                           mlir::CallOpInterface call) {
157   // Construct an operation from an unregistered dialect. This is discouraged
158   // and is only used here for brevity of the overall example.
159   mlir::OperationState state(call.getLoc(), "my.mm4");
160   state.types.assign(call->result_type_begin(), call->result_type_end());
161   state.operands.assign(call->operand_begin(), call->operand_end());
162 
163   mlir::Operation *replacement = rewriter.create(state);
164   rewriter.replaceOp(call, replacement->getResults());
165   return replacement;
166 }
167 
168 // See above for the signature description.
169 mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
170     mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call,
171     mlir::transform::ApplyToEachResultList &results,
172     mlir::transform::TransformState &state) {
173 
174   // Dispatch to the actual transformation.
175   Operation *replacement = replaceCallWithOp(rewriter, call);
176 
177   // Associate the payload operation produced by the rewrite with the result
178   // handle of this transform operation.
179   results.push_back(replacement);
180 
181   // If everything went well, return success.
182   return DiagnosedSilenceableFailure::success();
183 }
184 
185 //===---------------------------------------------------------------------===//
186 // CallOpInterfaceHandleType
187 //===---------------------------------------------------------------------===//
188 
189 // The interface declares this method to verify constraints this type has on
190 // payload operations. It returns the now familiar tri-state result.
191 mlir::DiagnosedSilenceableFailure
192 mlir::transform::CallOpInterfaceHandleType::checkPayload(
193     // Location at which diagnostics should be emitted.
194     mlir::Location loc,
195     // List of payload operations that are about to be associated with the
196     // handle that has this type.
197     llvm::ArrayRef<mlir::Operation *> payload) const {
198 
199   // All payload operations are expected to implement CallOpInterface, check
200   // this.
201   for (Operation *op : payload) {
202     if (llvm::isa<mlir::CallOpInterface>(op))
203       continue;
204 
205     // By convention, these verifiers always emit a silenceable failure since
206     // they are checking a precondition.
207     DiagnosedSilenceableFailure diag =
208         emitSilenceableError(loc)
209         << "expected the payload operation to implement CallOpInterface";
210     diag.attachNote(op->getLoc()) << "offending operation";
211     return diag;
212   }
213 
214   // If everything is okay, return success.
215   return DiagnosedSilenceableFailure::success();
216 }
217 
218 //===---------------------------------------------------------------------===//
219 // Extension registration
220 //===---------------------------------------------------------------------===//
221 
// Public entry point: adds MyExtension to `registry` so that the extension's
// types and operations become available to contexts built from it
// (presumably once the Transform dialect is loaded — confirm against
// DialectRegistry extension semantics).
void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
225