xref: /llvm-project/mlir/examples/transform/Ch3/lib/MyExtension.cpp (revision 2c1ae801e1b66a09a15028ae4ba614e0911eec00)
1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines Transform dialect extension operations used in the
10 // Chapter 3 of the Transform dialect tutorial.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "MyExtension.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/Interfaces/CallInterfaces.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define GET_TYPEDEF_CLASSES
24 #include "MyExtensionTypes.cpp.inc"
25 
26 #define GET_OP_CLASSES
27 #include "MyExtension.cpp.inc"
28 
29 //===---------------------------------------------------------------------===//
30 // MyExtension
31 //===---------------------------------------------------------------------===//
32 
// Define a new transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension
    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};
46 
// Populates the extension: declares the dialects its transformations may
// generate, then registers the ODS-generated transform types and operations.
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Register the additional transform dialect types with the dialect. List all
  // types generated from ODS.
  registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
      >();

  // ODS generates these helpers for type printing and parsing, but the
  // Transform dialect provides its own support for types supplied by the
  // extension. Reference these functions to avoid a compiler warning.
  (void)&generatedTypeParser;
  (void)&generatedTypePrinter;

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
92 
93 //===---------------------------------------------------------------------===//
94 // ChangeCallTargetOp
95 //===---------------------------------------------------------------------===//
96 
// Retargets `call` so that it invokes the symbol named `newTarget`. The call
// operation itself is mutated; no new operation is created.
static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
  call.setCallee(newTarget);
}
100 
101 // Implementation of our transform dialect operation.
102 // This operation returns a tri-state result that can be one of:
103 // - success when the transformation succeeded;
104 // - definite failure when the transformation failed in such a way that
105 // following
106 //   transformations are impossible or undesirable, typically it could have left
107 //   payload IR in an invalid state; it is expected that a diagnostic is emitted
108 //   immediately before returning the definite error;
109 // - silenceable failure when the transformation failed but following
110 // transformations
111 //   are still applicable, typically this means a precondition for the
112 //   transformation is not satisfied and the payload IR has not been modified.
113 // The silenceable failure additionally carries a Diagnostic that can be emitted
114 // to the user.
115 ::mlir::DiagnosedSilenceableFailure
116 mlir::transform::ChangeCallTargetOp::applyToOne(
117     // The rewriter that should be used when modifying IR.
118     ::mlir::transform::TransformRewriter &rewriter,
119     // The single payload operation to which the transformation is applied.
120     ::mlir::func::CallOp call,
121     // The payload IR entities that will be appended to lists associated with
122     // the results of this transform operation. This list contains one entry per
123     // result.
124     ::mlir::transform::ApplyToEachResultList &results,
125     // The transform application state. This object can be used to query the
126     // current associations between transform IR values and payload IR entities.
127     // It can also carry additional user-defined state.
128     ::mlir::transform::TransformState &state) {
129 
130   // Dispatch to the actual transformation.
131   updateCallee(call, getNewTarget());
132 
133   // If everything went well, return success.
134   return DiagnosedSilenceableFailure::success();
135 }
136 
137 void mlir::transform::ChangeCallTargetOp::getEffects(
138     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
139   // Indicate that the `call` handle is only read by this operation because the
140   // associated operation is not erased but rather modified in-place, so the
141   // reference to it remains valid.
142   onlyReadsHandle(getCallMutable(), effects);
143 
144   // Indicate that the payload is modified by this operation.
145   modifiesPayload(effects);
146 }
147 
148 //===---------------------------------------------------------------------===//
149 // CallToOp
150 //===---------------------------------------------------------------------===//
151 
152 static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter,
153                                           mlir::CallOpInterface call) {
154   // Construct an operation from an unregistered dialect. This is discouraged
155   // and is only used here for brevity of the overall example.
156   mlir::OperationState state(call.getLoc(), "my.mm4");
157   state.types.assign(call->result_type_begin(), call->result_type_end());
158   state.operands.assign(call->operand_begin(), call->operand_end());
159 
160   mlir::Operation *replacement = rewriter.create(state);
161   rewriter.replaceOp(call, replacement->getResults());
162   return replacement;
163 }
164 
165 // See above for the signature description.
166 mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
167     mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call,
168     mlir::transform::ApplyToEachResultList &results,
169     mlir::transform::TransformState &state) {
170 
171   // Dispatch to the actual transformation.
172   Operation *replacement = replaceCallWithOp(rewriter, call);
173 
174   // Associate the payload operation produced by the rewrite with the result
175   // handle of this transform operation.
176   results.push_back(replacement);
177 
178   // If everything went well, return success.
179   return DiagnosedSilenceableFailure::success();
180 }
181 
182 //===---------------------------------------------------------------------===//
183 // CallOpInterfaceHandleType
184 //===---------------------------------------------------------------------===//
185 
186 // The interface declares this method to verify constraints this type has on
187 // payload operations. It returns the now familiar tri-state result.
188 mlir::DiagnosedSilenceableFailure
189 mlir::transform::CallOpInterfaceHandleType::checkPayload(
190     // Location at which diagnostics should be emitted.
191     mlir::Location loc,
192     // List of payload operations that are about to be associated with the
193     // handle that has this type.
194     llvm::ArrayRef<mlir::Operation *> payload) const {
195 
196   // All payload operations are expected to implement CallOpInterface, check
197   // this.
198   for (Operation *op : payload) {
199     if (llvm::isa<mlir::CallOpInterface>(op))
200       continue;
201 
202     // By convention, these verifiers always emit a silenceable failure since
203     // they are checking a precondition.
204     DiagnosedSilenceableFailure diag =
205         emitSilenceableError(loc)
206         << "expected the payload operation to implement CallOpInterface";
207     diag.attachNote(op->getLoc()) << "offending operation";
208     return diag;
209   }
210 
211   // If everything is okay, return success.
212   return DiagnosedSilenceableFailure::success();
213 }
214 
215 //===---------------------------------------------------------------------===//
216 // Extension registration
217 //===---------------------------------------------------------------------===//
218 
// Adds MyExtension to `registry`. Presumably the extension's init() takes
// effect once the Transform dialect is loaded in a context built from this
// registry — confirm against the DialectRegistry extension documentation.
void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
222