xref: /llvm-project/mlir/examples/transform/Ch2/lib/MyExtension.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines Transform dialect extension operations used in the
10 // Chapter 2 of the Transform dialect tutorial.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "MyExtension.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/IR/DialectRegistry.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 
27 // Define a new transform dialect extension. This uses the CRTP idiom to
28 // identify extensions.
29 class MyExtension
30     : public ::mlir::transform::TransformDialectExtension<MyExtension> {
31 public:
32   // The TypeID of this extension.
33   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
34 
35   // The extension must derive the base constructor.
36   using Base::Base;
37 
38   // This function initializes the extension, similarly to `initialize` in
39   // dialect definitions. List individual operations and dependent dialects
40   // here.
41   void init();
42 };
43 
44 void MyExtension::init() {
45   // Similarly to dialects, an extension can declare a dependent dialect. This
46   // dialect will be loaded along with the extension and, therefore, along with
47   // the Transform dialect. Only declare as dependent the dialects that contain
48   // the attributes or types used by transform operations. Do NOT declare as
49   // dependent the dialects produced during the transformation.
50   // declareDependentDialect<MyDialect>();
51 
52   // When transformations are applied, they may produce new operations from
53   // previously unloaded dialects. Typically, a pass would need to declare
54   // itself dependent on the dialects containing such new operations. To avoid
55   // confusion with the dialects the extension itself depends on, the Transform
56   // dialects differentiates between:
57   //   - dependent dialects, which are used by the transform operations, and
58   //   - generated dialects, which contain the entities (attributes, operations,
59   //     types) that may be produced by applying the transformation even when
60   //     not present in the original payload IR.
61   // In the following chapter, we will be add operations that generate function
62   // calls and structured control flow operations, so let's declare the
63   // corresponding dialects as generated.
64   declareGeneratedDialect<::mlir::scf::SCFDialect>();
65   declareGeneratedDialect<::mlir::func::FuncDialect>();
66 
67   // Finally, we register the additional transform operations with the dialect.
68   // List all operations generated from ODS. This call will perform additional
69   // checks that the operations implement the transform and memory effect
70   // interfaces required by the dialect interpreter and assert if they do not.
71   registerTransformOps<
72 #define GET_OP_LIST
73 #include "MyExtension.cpp.inc"
74       >();
75 }
76 
77 #define GET_OP_CLASSES
78 #include "MyExtension.cpp.inc"
79 
80 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
81   call.setCallee(newTarget);
82 }
83 
84 // Implementation of our transform dialect operation.
85 // This operation returns a tri-state result that can be one of:
86 // - success when the transformation succeeded;
87 // - definite failure when the transformation failed in such a way that
88 //   following transformations are impossible or undesirable, typically it could
89 //   have left payload IR in an invalid state; it is expected that a diagnostic
90 //   is emitted immediately before returning the definite error;
91 // - silenceable failure when the transformation failed but following
92 //   transformations are still applicable, typically this means a precondition
93 //   for the transformation is not satisfied and the payload IR has not been
94 //   modified. The silenceable failure additionally carries a Diagnostic that
95 //   can be emitted to the user.
96 ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
97     // The rewriter that should be used when modifying IR.
98     ::mlir::transform::TransformRewriter &rewriter,
99     // The list of payload IR entities that will be associated with the
100     // transform IR values defined by this transform operation. In this case, it
101     // can remain empty as there are no results.
102     ::mlir::transform::TransformResults &results,
103     // The transform application state. This object can be used to query the
104     // current associations between transform IR values and payload IR entities.
105     // It can also carry additional user-defined state.
106     ::mlir::transform::TransformState &state) {
107 
108   // First, we need to obtain the list of payload operations that are associated
109   // with the operand handle.
110   auto payload = state.getPayloadOps(getCall());
111 
112   // Then, we iterate over the list of operands and call the actual IR-mutating
113   // function. We also check the preconditions here.
114   for (Operation *payloadOp : payload) {
115     auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
116     if (!call) {
117       DiagnosedSilenceableFailure diag =
118           emitSilenceableError() << "only applies to func.call payloads";
119       diag.attachNote(payloadOp->getLoc()) << "offending payload";
120       return diag;
121     }
122 
123     updateCallee(call, getNewTarget());
124   }
125 
126   // If everything went well, return success.
127   return DiagnosedSilenceableFailure::success();
128 }
129 
130 void mlir::transform::ChangeCallTargetOp::getEffects(
131     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
132   // Indicate that the `call` handle is only read by this operation because the
133   // associated operation is not erased but rather modified in-place, so the
134   // reference to it remains valid.
135   onlyReadsHandle(getCallMutable(), effects);
136 
137   // Indicate that the payload is modified by this operation.
138   modifiesPayload(effects);
139 }
140 
141 void registerMyExtension(::mlir::DialectRegistry &registry) {
142   registry.addExtensions<MyExtension>();
143 }
144