1 //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines Transform dialect extension operations used in the 10 // Chapter 2 of the Transform dialect tutorial. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "MyExtension.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 20 #include "mlir/IR/DialectRegistry.h" 21 #include "mlir/IR/Operation.h" 22 #include "mlir/Interfaces/SideEffectInterfaces.h" 23 #include "mlir/Support/LLVM.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/ADT/StringRef.h" 26 27 // Define a new transform dialect extension. This uses the CRTP idiom to 28 // identify extensions. 29 class MyExtension 30 : public ::mlir::transform::TransformDialectExtension<MyExtension> { 31 public: 32 // The TypeID of this extension. 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) 34 35 // The extension must derive the base constructor. 36 using Base::Base; 37 38 // This function initializes the extension, similarly to `initialize` in 39 // dialect definitions. List individual operations and dependent dialects 40 // here. 41 void init(); 42 }; 43 44 void MyExtension::init() { 45 // Similarly to dialects, an extension can declare a dependent dialect. This 46 // dialect will be loaded along with the extension and, therefore, along with 47 // the Transform dialect. Only declare as dependent the dialects that contain 48 // the attributes or types used by transform operations. Do NOT declare as 49 // dependent the dialects produced during the transformation. 50 // declareDependentDialect<MyDialect>(); 51 52 // When transformations are applied, they may produce new operations from 53 // previously unloaded dialects. Typically, a pass would need to declare 54 // itself dependent on the dialects containing such new operations. To avoid 55 // confusion with the dialects the extension itself depends on, the Transform 56 // dialects differentiates between: 57 // - dependent dialects, which are used by the transform operations, and 58 // - generated dialects, which contain the entities (attributes, operations, 59 // types) that may be produced by applying the transformation even when 60 // not present in the original payload IR. 61 // In the following chapter, we will be add operations that generate function 62 // calls and structured control flow operations, so let's declare the 63 // corresponding dialects as generated. 64 declareGeneratedDialect<::mlir::scf::SCFDialect>(); 65 declareGeneratedDialect<::mlir::func::FuncDialect>(); 66 67 // Finally, we register the additional transform operations with the dialect. 68 // List all operations generated from ODS. This call will perform additional 69 // checks that the operations implement the transform and memory effect 70 // interfaces required by the dialect interpreter and assert if they do not. 71 registerTransformOps< 72 #define GET_OP_LIST 73 #include "MyExtension.cpp.inc" 74 >(); 75 } 76 77 #define GET_OP_CLASSES 78 #include "MyExtension.cpp.inc" 79 80 static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { 81 call.setCallee(newTarget); 82 } 83 84 // Implementation of our transform dialect operation. 85 // This operation returns a tri-state result that can be one of: 86 // - success when the transformation succeeded; 87 // - definite failure when the transformation failed in such a way that 88 // following transformations are impossible or undesirable, typically it could 89 // have left payload IR in an invalid state; it is expected that a diagnostic 90 // is emitted immediately before returning the definite error; 91 // - silenceable failure when the transformation failed but following 92 // transformations are still applicable, typically this means a precondition 93 // for the transformation is not satisfied and the payload IR has not been 94 // modified. The silenceable failure additionally carries a Diagnostic that 95 // can be emitted to the user. 96 ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( 97 // The rewriter that should be used when modifying IR. 98 ::mlir::transform::TransformRewriter &rewriter, 99 // The list of payload IR entities that will be associated with the 100 // transform IR values defined by this transform operation. In this case, it 101 // can remain empty as there are no results. 102 ::mlir::transform::TransformResults &results, 103 // The transform application state. This object can be used to query the 104 // current associations between transform IR values and payload IR entities. 105 // It can also carry additional user-defined state. 106 ::mlir::transform::TransformState &state) { 107 108 // First, we need to obtain the list of payload operations that are associated 109 // with the operand handle. 110 auto payload = state.getPayloadOps(getCall()); 111 112 // Then, we iterate over the list of operands and call the actual IR-mutating 113 // function. We also check the preconditions here. 114 for (Operation *payloadOp : payload) { 115 auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); 116 if (!call) { 117 DiagnosedSilenceableFailure diag = 118 emitSilenceableError() << "only applies to func.call payloads"; 119 diag.attachNote(payloadOp->getLoc()) << "offending payload"; 120 return diag; 121 } 122 123 updateCallee(call, getNewTarget()); 124 } 125 126 // If everything went well, return success. 127 return DiagnosedSilenceableFailure::success(); 128 } 129 130 void mlir::transform::ChangeCallTargetOp::getEffects( 131 ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { 132 // Indicate that the `call` handle is only read by this operation because the 133 // associated operation is not erased but rather modified in-place, so the 134 // reference to it remains valid. 135 onlyReadsHandle(getCallMutable(), effects); 136 137 // Indicate that the payload is modified by this operation. 138 modifiesPayload(effects); 139 } 140 141 void registerMyExtension(::mlir::DialectRegistry ®istry) { 142 registry.addExtensions<MyExtension>(); 143 } 144