1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines Transform dialect extension operations used in the 10 // Chapter 3 of the Transform dialect tutorial. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "MyExtension.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/Interfaces/CallInterfaces.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 #define GET_TYPEDEF_CLASSES 24 #include "MyExtensionTypes.cpp.inc" 25 26 #define GET_OP_CLASSES 27 #include "MyExtension.cpp.inc" 28 29 //===---------------------------------------------------------------------===// 30 // MyExtension 31 //===---------------------------------------------------------------------===// 32 33 // Define a new transform dialect extension. This uses the CRTP idiom to 34 // identify extensions. 35 class MyExtension 36 : public ::mlir::transform::TransformDialectExtension<MyExtension> { 37 public: 38 // The TypeID of this extension. 39 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) 40 41 // The extension must derive the base constructor. 42 using Base::Base; 43 44 // This function initializes the extension, similarly to `initialize` in 45 // dialect definitions. List individual operations and dependent dialects 46 // here. 47 void init(); 48 }; 49 50 void MyExtension::init() { 51 // Similarly to dialects, an extension can declare a dependent dialect. This 52 // dialect will be loaded along with the extension and, therefore, along with 53 // the Transform dialect. Only declare as dependent the dialects that contain 54 // the attributes or types used by transform operations. Do NOT declare as 55 // dependent the dialects produced during the transformation. 56 // declareDependentDialect<MyDialect>(); 57 58 // When transformations are applied, they may produce new operations from 59 // previously unloaded dialects. Typically, a pass would need to declare 60 // itself dependent on the dialects containing such new operations. To avoid 61 // confusion with the dialects the extension itself depends on, the Transform 62 // dialects differentiates between: 63 // - dependent dialects, which are used by the transform operations, and 64 // - generated dialects, which contain the entities (attributes, operations, 65 // types) that may be produced by applying the transformation even when 66 // not present in the original payload IR. 67 // In the following chapter, we will be add operations that generate function 68 // calls and structured control flow operations, so let's declare the 69 // corresponding dialects as generated. 70 declareGeneratedDialect<::mlir::scf::SCFDialect>(); 71 declareGeneratedDialect<::mlir::func::FuncDialect>(); 72 73 // Register the additional transform dialect types with the dialect. List all 74 // types generated from ODS. 75 registerTypes< 76 #define GET_TYPEDEF_LIST 77 #include "MyExtensionTypes.cpp.inc" 78 >(); 79 80 // ODS generates these helpers for type printing and parsing, but the 81 // Transform dialect provides its own support for types supplied by the 82 // extension. Reference these functions to avoid a compiler warning. 83 (void)&generatedTypeParser; 84 (void)&generatedTypePrinter; 85 86 // Finally, we register the additional transform operations with the dialect. 87 // List all operations generated from ODS. This call will perform additional 88 // checks that the operations implement the transform and memory effect 89 // interfaces required by the dialect interpreter and assert if they do not. 90 registerTransformOps< 91 #define GET_OP_LIST 92 #include "MyExtension.cpp.inc" 93 >(); 94 } 95 96 //===---------------------------------------------------------------------===// 97 // ChangeCallTargetOp 98 //===---------------------------------------------------------------------===// 99 100 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { 101 call.setCallee(newTarget); 102 } 103 104 // Implementation of our transform dialect operation. 105 // This operation returns a tri-state result that can be one of: 106 // - success when the transformation succeeded; 107 // - definite failure when the transformation failed in such a way that 108 // following 109 // transformations are impossible or undesirable, typically it could have left 110 // payload IR in an invalid state; it is expected that a diagnostic is emitted 111 // immediately before returning the definite error; 112 // - silenceable failure when the transformation failed but following 113 // transformations 114 // are still applicable, typically this means a precondition for the 115 // transformation is not satisfied and the payload IR has not been modified. 116 // The silenceable failure additionally carries a Diagnostic that can be emitted 117 // to the user. 118 ::mlir::DiagnosedSilenceableFailure 119 mlir::transform::ChangeCallTargetOp::applyToOne( 120 // The rewriter that should be used when modifying IR. 121 ::mlir::transform::TransformRewriter &rewriter, 122 // The single payload operation to which the transformation is applied. 123 ::mlir::func::CallOp call, 124 // The payload IR entities that will be appended to lists associated with 125 // the results of this transform operation. This list contains one entry per 126 // result. 127 ::mlir::transform::ApplyToEachResultList &results, 128 // The transform application state. This object can be used to query the 129 // current associations between transform IR values and payload IR entities. 130 // It can also carry additional user-defined state. 131 ::mlir::transform::TransformState &state) { 132 133 // Dispatch to the actual transformation. 134 updateCallee(call, getNewTarget()); 135 136 // If everything went well, return success. 137 return DiagnosedSilenceableFailure::success(); 138 } 139 140 void mlir::transform::ChangeCallTargetOp::getEffects( 141 ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { 142 // Indicate that the `call` handle is only read by this operation because the 143 // associated operation is not erased but rather modified in-place, so the 144 // reference to it remains valid. 145 onlyReadsHandle(getCallMutable(), effects); 146 147 // Indicate that the payload is modified by this operation. 148 modifiesPayload(effects); 149 } 150 151 //===---------------------------------------------------------------------===// 152 // CallToOp 153 //===---------------------------------------------------------------------===// 154 155 static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, 156 mlir::CallOpInterface call) { 157 // Construct an operation from an unregistered dialect. This is discouraged 158 // and is only used here for brevity of the overall example. 159 mlir::OperationState state(call.getLoc(), "my.mm4"); 160 state.types.assign(call->result_type_begin(), call->result_type_end()); 161 state.operands.assign(call->operand_begin(), call->operand_end()); 162 163 mlir::Operation *replacement = rewriter.create(state); 164 rewriter.replaceOp(call, replacement->getResults()); 165 return replacement; 166 } 167 168 // See above for the signature description. 169 mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( 170 mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, 171 mlir::transform::ApplyToEachResultList &results, 172 mlir::transform::TransformState &state) { 173 174 // Dispatch to the actual transformation. 175 Operation *replacement = replaceCallWithOp(rewriter, call); 176 177 // Associate the payload operation produced by the rewrite with the result 178 // handle of this transform operation. 179 results.push_back(replacement); 180 181 // If everything went well, return success. 182 return DiagnosedSilenceableFailure::success(); 183 } 184 185 //===---------------------------------------------------------------------===// 186 // CallOpInterfaceHandleType 187 //===---------------------------------------------------------------------===// 188 189 // The interface declares this method to verify constraints this type has on 190 // payload operations. It returns the now familiar tri-state result. 191 mlir::DiagnosedSilenceableFailure 192 mlir::transform::CallOpInterfaceHandleType::checkPayload( 193 // Location at which diagnostics should be emitted. 194 mlir::Location loc, 195 // List of payload operations that are about to be associated with the 196 // handle that has this type. 197 llvm::ArrayRef<mlir::Operation *> payload) const { 198 199 // All payload operations are expected to implement CallOpInterface, check 200 // this. 201 for (Operation *op : payload) { 202 if (llvm::isa<mlir::CallOpInterface>(op)) 203 continue; 204 205 // By convention, these verifiers always emit a silenceable failure since 206 // they are checking a precondition. 207 DiagnosedSilenceableFailure diag = 208 emitSilenceableError(loc) 209 << "expected the payload operation to implement CallOpInterface"; 210 diag.attachNote(op->getLoc()) << "offending operation"; 211 return diag; 212 } 213 214 // If everything is okay, return success. 215 return DiagnosedSilenceableFailure::success(); 216 } 217 218 //===---------------------------------------------------------------------===// 219 // Extension registration 220 //===---------------------------------------------------------------------===// 221 222 void registerMyExtension(::mlir::DialectRegistry ®istry) { 223 registry.addExtensions<MyExtension>(); 224 } 225