1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines Transform dialect extension operations used in the 10 // Chapter 3 of the Transform dialect tutorial. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "MyExtension.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/Interfaces/CallInterfaces.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 #define GET_TYPEDEF_CLASSES 23 #include "MyExtensionTypes.cpp.inc" 24 25 #define GET_OP_CLASSES 26 #include "MyExtension.cpp.inc" 27 28 //===---------------------------------------------------------------------===// 29 // MyExtension 30 //===---------------------------------------------------------------------===// 31 32 // Define a new transform dialect extension. This uses the CRTP idiom to 33 // identify extensions. 34 class MyExtension 35 : public ::mlir::transform::TransformDialectExtension<MyExtension> { 36 public: 37 // The extension must derive the base constructor. 38 using Base::Base; 39 40 // This function initializes the extension, similarly to `initialize` in 41 // dialect definitions. List individual operations and dependent dialects 42 // here. 43 void init(); 44 }; 45 46 void MyExtension::init() { 47 // Similarly to dialects, an extension can declare a dependent dialect. This 48 // dialect will be loaded along with the extension and, therefore, along with 49 // the Transform dialect. Only declare as dependent the dialects that contain 50 // the attributes or types used by transform operations. Do NOT declare as 51 // dependent the dialects produced during the transformation. 52 // declareDependentDialect<MyDialect>(); 53 54 // When transformations are applied, they may produce new operations from 55 // previously unloaded dialects. Typically, a pass would need to declare 56 // itself dependent on the dialects containing such new operations. To avoid 57 // confusion with the dialects the extension itself depends on, the Transform 58 // dialects differentiates between: 59 // - dependent dialects, which are used by the transform operations, and 60 // - generated dialects, which contain the entities (attributes, operations, 61 // types) that may be produced by applying the transformation even when 62 // not present in the original payload IR. 63 // In the following chapter, we will be add operations that generate function 64 // calls and structured control flow operations, so let's declare the 65 // corresponding dialects as generated. 66 declareGeneratedDialect<::mlir::scf::SCFDialect>(); 67 declareGeneratedDialect<::mlir::func::FuncDialect>(); 68 69 // Register the additional transform dialect types with the dialect. List all 70 // types generated from ODS. 71 registerTypes< 72 #define GET_TYPEDEF_LIST 73 #include "MyExtensionTypes.cpp.inc" 74 >(); 75 76 // ODS generates these helpers for type printing and parsing, but the 77 // Transform dialect provides its own support for types supplied by the 78 // extension. Reference these functions to avoid a compiler warning. 79 (void)generatedTypeParser; 80 (void)generatedTypePrinter; 81 82 // Finally, we register the additional transform operations with the dialect. 83 // List all operations generated from ODS. This call will perform additional 84 // checks that the operations implement the transform and memory effect 85 // interfaces required by the dialect interpreter and assert if they do not. 86 registerTransformOps< 87 #define GET_OP_LIST 88 #include "MyExtension.cpp.inc" 89 >(); 90 } 91 92 //===---------------------------------------------------------------------===// 93 // ChangeCallTargetOp 94 //===---------------------------------------------------------------------===// 95 96 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { 97 call.setCallee(newTarget); 98 } 99 100 // Implementation of our transform dialect operation. 101 // This operation returns a tri-state result that can be one of: 102 // - success when the transformation succeeded; 103 // - definite failure when the transformation failed in such a way that 104 // following 105 // transformations are impossible or undesirable, typically it could have left 106 // payload IR in an invalid state; it is expected that a diagnostic is emitted 107 // immediately before returning the definite error; 108 // - silenceable failure when the transformation failed but following 109 // transformations 110 // are still applicable, typically this means a precondition for the 111 // transformation is not satisfied and the payload IR has not been modified. 112 // The silenceable failure additionally carries a Diagnostic that can be emitted 113 // to the user. 114 ::mlir::DiagnosedSilenceableFailure 115 mlir::transform::ChangeCallTargetOp::applyToOne( 116 // The single payload operation to which the transformation is applied. 117 ::mlir::func::CallOp call, 118 // The payload IR entities that will be appended to lists associated with 119 // the results of this transform operation. This list contains one entry per 120 // result. 121 ::mlir::transform::ApplyToEachResultList &results, 122 // The transform application state. This object can be used to query the 123 // current associations between transform IR values and payload IR entities. 124 // It can also carry additional user-defined state. 125 ::mlir::transform::TransformState &state) { 126 127 // Dispatch to the actual transformation. 128 updateCallee(call, getNewTarget()); 129 130 // If everything went well, return success. 131 return DiagnosedSilenceableFailure::success(); 132 } 133 134 void mlir::transform::ChangeCallTargetOp::getEffects( 135 ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { 136 // Indicate that the `call` handle is only read by this operation because the 137 // associated operation is not erased but rather modified in-place, so the 138 // reference to it remains valid. 139 onlyReadsHandle(getCall(), effects); 140 141 // Indicate that the payload is modified by this operation. 142 modifiesPayload(effects); 143 } 144 145 //===---------------------------------------------------------------------===// 146 // CallToOp 147 //===---------------------------------------------------------------------===// 148 149 static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) { 150 // Construct an operation from an unregistered dialect. This is discouraged 151 // and is only used here for brevity of the overall example. 152 mlir::OperationState state(call.getLoc(), "my.mm4"); 153 state.types.assign(call->result_type_begin(), call->result_type_end()); 154 state.operands.assign(call->operand_begin(), call->operand_end()); 155 156 mlir::OpBuilder builder(call); 157 mlir::Operation *replacement = builder.create(state); 158 call->replaceAllUsesWith(replacement->getResults()); 159 call->erase(); 160 return replacement; 161 } 162 163 // See above for the signature description. 164 mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( 165 mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, 166 mlir::transform::TransformState &state) { 167 168 // Dispatch to the actual transformation. 169 Operation *replacement = replaceCallWithOp(call); 170 171 // Associate the payload operation produced by the rewrite with the result 172 // handle of this transform operation. 173 results.push_back(replacement); 174 175 // If everything went well, return success. 176 return DiagnosedSilenceableFailure::success(); 177 } 178 179 //===---------------------------------------------------------------------===// 180 // CallOpInterfaceHandleType 181 //===---------------------------------------------------------------------===// 182 183 // The interface declares this method to verify constraints this type has on 184 // payload operations. It returns the now familiar tri-state result. 185 mlir::DiagnosedSilenceableFailure 186 mlir::transform::CallOpInterfaceHandleType::checkPayload( 187 // Location at which diagnostics should be emitted. 188 mlir::Location loc, 189 // List of payload operations that are about to be associated with the 190 // handle that has this type. 191 llvm::ArrayRef<mlir::Operation *> payload) const { 192 193 // All payload operations are expected to implement CallOpInterface, check 194 // this. 195 for (Operation *op : payload) { 196 if (llvm::isa<mlir::CallOpInterface>(op)) 197 continue; 198 199 // By convention, these verifiers always emit a silenceable failure since 200 // they are checking a precondition. 201 DiagnosedSilenceableFailure diag = 202 emitSilenceableError(loc) 203 << "expected the payload operation to implement CallOpInterface"; 204 diag.attachNote(op->getLoc()) << "offending operation"; 205 return diag; 206 } 207 208 // If everything is okay, return success. 209 return DiagnosedSilenceableFailure::success(); 210 } 211 212 //===---------------------------------------------------------------------===// 213 // Extension registration 214 //===---------------------------------------------------------------------===// 215 216 void registerMyExtension(::mlir::DialectRegistry ®istry) { 217 registry.addExtensions<MyExtension>(); 218 } 219