xref: /llvm-project/mlir/examples/transform/Ch3/lib/MyExtension.cpp (revision 68ae0d7803e43146b28f94f62357226047af7d9a)
1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines Transform dialect extension operations used in the
10 // Chapter 3 of the Transform dialect tutorial.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "MyExtension.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/Interfaces/CallInterfaces.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define GET_TYPEDEF_CLASSES
23 #include "MyExtensionTypes.cpp.inc"
24 
25 #define GET_OP_CLASSES
26 #include "MyExtension.cpp.inc"
27 
28 //===---------------------------------------------------------------------===//
29 // MyExtension
30 //===---------------------------------------------------------------------===//
31 
32 // Define a new transform dialect extension. This uses the CRTP idiom to
33 // identify extensions.
34 class MyExtension
35     : public ::mlir::transform::TransformDialectExtension<MyExtension> {
36 public:
37   // The extension must derive the base constructor.
38   using Base::Base;
39 
40   // This function initializes the extension, similarly to `initialize` in
41   // dialect definitions. List individual operations and dependent dialects
42   // here.
43   void init();
44 };
45 
46 void MyExtension::init() {
47   // Similarly to dialects, an extension can declare a dependent dialect. This
48   // dialect will be loaded along with the extension and, therefore, along with
49   // the Transform dialect. Only declare as dependent the dialects that contain
50   // the attributes or types used by transform operations. Do NOT declare as
51   // dependent the dialects produced during the transformation.
52   // declareDependentDialect<MyDialect>();
53 
54   // When transformations are applied, they may produce new operations from
55   // previously unloaded dialects. Typically, a pass would need to declare
56   // itself dependent on the dialects containing such new operations. To avoid
57   // confusion with the dialects the extension itself depends on, the Transform
58   // dialects differentiates between:
59   //   - dependent dialects, which are used by the transform operations, and
60   //   - generated dialects, which contain the entities (attributes, operations,
61   //     types) that may be produced by applying the transformation even when
62   //     not present in the original payload IR.
63   // In the following chapter, we will be add operations that generate function
64   // calls and structured control flow operations, so let's declare the
65   // corresponding dialects as generated.
66   declareGeneratedDialect<::mlir::scf::SCFDialect>();
67   declareGeneratedDialect<::mlir::func::FuncDialect>();
68 
69   // Register the additional transform dialect types with the dialect. List all
70   // types generated from ODS.
71   registerTypes<
72 #define GET_TYPEDEF_LIST
73 #include "MyExtensionTypes.cpp.inc"
74       >();
75 
76   // ODS generates these helpers for type printing and parsing, but the
77   // Transform dialect provides its own support for types supplied by the
78   // extension. Reference these functions to avoid a compiler warning.
79   (void)generatedTypeParser;
80   (void)generatedTypePrinter;
81 
82   // Finally, we register the additional transform operations with the dialect.
83   // List all operations generated from ODS. This call will perform additional
84   // checks that the operations implement the transform and memory effect
85   // interfaces required by the dialect interpreter and assert if they do not.
86   registerTransformOps<
87 #define GET_OP_LIST
88 #include "MyExtension.cpp.inc"
89       >();
90 }
91 
92 //===---------------------------------------------------------------------===//
93 // ChangeCallTargetOp
94 //===---------------------------------------------------------------------===//
95 
96 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
97   call.setCallee(newTarget);
98 }
99 
100 // Implementation of our transform dialect operation.
101 // This operation returns a tri-state result that can be one of:
102 // - success when the transformation succeeded;
103 // - definite failure when the transformation failed in such a way that
104 // following
105 //   transformations are impossible or undesirable, typically it could have left
106 //   payload IR in an invalid state; it is expected that a diagnostic is emitted
107 //   immediately before returning the definite error;
108 // - silenceable failure when the transformation failed but following
109 // transformations
110 //   are still applicable, typically this means a precondition for the
111 //   transformation is not satisfied and the payload IR has not been modified.
112 // The silenceable failure additionally carries a Diagnostic that can be emitted
113 // to the user.
114 ::mlir::DiagnosedSilenceableFailure
115 mlir::transform::ChangeCallTargetOp::applyToOne(
116     // The single payload operation to which the transformation is applied.
117     ::mlir::func::CallOp call,
118     // The payload IR entities that will be appended to lists associated with
119     // the results of this transform operation. This list contains one entry per
120     // result.
121     ::mlir::transform::ApplyToEachResultList &results,
122     // The transform application state. This object can be used to query the
123     // current associations between transform IR values and payload IR entities.
124     // It can also carry additional user-defined state.
125     ::mlir::transform::TransformState &state) {
126 
127   // Dispatch to the actual transformation.
128   updateCallee(call, getNewTarget());
129 
130   // If everything went well, return success.
131   return DiagnosedSilenceableFailure::success();
132 }
133 
134 void mlir::transform::ChangeCallTargetOp::getEffects(
135     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
136   // Indicate that the `call` handle is only read by this operation because the
137   // associated operation is not erased but rather modified in-place, so the
138   // reference to it remains valid.
139   onlyReadsHandle(getCall(), effects);
140 
141   // Indicate that the payload is modified by this operation.
142   modifiesPayload(effects);
143 }
144 
145 //===---------------------------------------------------------------------===//
146 // CallToOp
147 //===---------------------------------------------------------------------===//
148 
149 static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) {
150   // Construct an operation from an unregistered dialect. This is discouraged
151   // and is only used here for brevity of the overall example.
152   mlir::OperationState state(call.getLoc(), "my.mm4");
153   state.types.assign(call->result_type_begin(), call->result_type_end());
154   state.operands.assign(call->operand_begin(), call->operand_end());
155 
156   mlir::OpBuilder builder(call);
157   mlir::Operation *replacement = builder.create(state);
158   call->replaceAllUsesWith(replacement->getResults());
159   call->erase();
160   return replacement;
161 }
162 
163 // See above for the signature description.
164 mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
165     mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results,
166     mlir::transform::TransformState &state) {
167 
168   // Dispatch to the actual transformation.
169   Operation *replacement = replaceCallWithOp(call);
170 
171   // Associate the payload operation produced by the rewrite with the result
172   // handle of this transform operation.
173   results.push_back(replacement);
174 
175   // If everything went well, return success.
176   return DiagnosedSilenceableFailure::success();
177 }
178 
179 //===---------------------------------------------------------------------===//
180 // CallOpInterfaceHandleType
181 //===---------------------------------------------------------------------===//
182 
183 // The interface declares this method to verify constraints this type has on
184 // payload operations. It returns the now familiar tri-state result.
185 mlir::DiagnosedSilenceableFailure
186 mlir::transform::CallOpInterfaceHandleType::checkPayload(
187     // Location at which diagnostics should be emitted.
188     mlir::Location loc,
189     // List of payload operations that are about to be associated with the
190     // handle that has this type.
191     llvm::ArrayRef<mlir::Operation *> payload) const {
192 
193   // All payload operations are expected to implement CallOpInterface, check
194   // this.
195   for (Operation *op : payload) {
196     if (llvm::isa<mlir::CallOpInterface>(op))
197       continue;
198 
199     // By convention, these verifiers always emit a silenceable failure since
200     // they are checking a precondition.
201     DiagnosedSilenceableFailure diag =
202         emitSilenceableError(loc)
203         << "expected the payload operation to implement CallOpInterface";
204     diag.attachNote(op->getLoc()) << "offending operation";
205     return diag;
206   }
207 
208   // If everything is okay, return success.
209   return DiagnosedSilenceableFailure::success();
210 }
211 
212 //===---------------------------------------------------------------------===//
213 // Extension registration
214 //===---------------------------------------------------------------------===//
215 
216 void registerMyExtension(::mlir::DialectRegistry &registry) {
217   registry.addExtensions<MyExtension>();
218 }
219