168ae0d78SAlex Zinenko //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 268ae0d78SAlex Zinenko // 368ae0d78SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 468ae0d78SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 568ae0d78SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 668ae0d78SAlex Zinenko // 768ae0d78SAlex Zinenko //===----------------------------------------------------------------------===// 868ae0d78SAlex Zinenko // 968ae0d78SAlex Zinenko // This file defines Transform dialect extension operations used in the 1068ae0d78SAlex Zinenko // Chapter 2 of the Transform dialect tutorial. 1168ae0d78SAlex Zinenko // 1268ae0d78SAlex Zinenko //===----------------------------------------------------------------------===// 1368ae0d78SAlex Zinenko 1468ae0d78SAlex Zinenko #include "MyExtension.h" 1568ae0d78SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h" 1668ae0d78SAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 1768ae0d78SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18ec6da065SMehdi Amini #include "mlir/Dialect/Transform/IR/TransformTypes.h" 195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 20ec6da065SMehdi Amini #include "mlir/IR/DialectRegistry.h" 21ec6da065SMehdi Amini #include "mlir/IR/Operation.h" 22ec6da065SMehdi Amini #include "mlir/Interfaces/SideEffectInterfaces.h" 23ec6da065SMehdi Amini #include "mlir/Support/LLVM.h" 24ec6da065SMehdi Amini #include "llvm/ADT/SmallVector.h" 25ec6da065SMehdi Amini #include "llvm/ADT/StringRef.h" 2668ae0d78SAlex Zinenko 2768ae0d78SAlex Zinenko // Define a new transform dialect extension. This uses the CRTP idiom to 2868ae0d78SAlex Zinenko // identify extensions. 2968ae0d78SAlex Zinenko class MyExtension 3068ae0d78SAlex Zinenko : public ::mlir::transform::TransformDialectExtension<MyExtension> { 3168ae0d78SAlex Zinenko public: 32*84cc1865SNikhil Kalra // The TypeID of this extension. 33*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) 34*84cc1865SNikhil Kalra 3568ae0d78SAlex Zinenko // The extension must derive the base constructor. 3668ae0d78SAlex Zinenko using Base::Base; 3768ae0d78SAlex Zinenko 3868ae0d78SAlex Zinenko // This function initializes the extension, similarly to `initialize` in 3968ae0d78SAlex Zinenko // dialect definitions. List individual operations and dependent dialects 4068ae0d78SAlex Zinenko // here. 4168ae0d78SAlex Zinenko void init(); 4268ae0d78SAlex Zinenko }; 4368ae0d78SAlex Zinenko 4468ae0d78SAlex Zinenko void MyExtension::init() { 4568ae0d78SAlex Zinenko // Similarly to dialects, an extension can declare a dependent dialect. This 4668ae0d78SAlex Zinenko // dialect will be loaded along with the extension and, therefore, along with 4768ae0d78SAlex Zinenko // the Transform dialect. Only declare as dependent the dialects that contain 4868ae0d78SAlex Zinenko // the attributes or types used by transform operations. Do NOT declare as 4968ae0d78SAlex Zinenko // dependent the dialects produced during the transformation. 5068ae0d78SAlex Zinenko // declareDependentDialect<MyDialect>(); 5168ae0d78SAlex Zinenko 5268ae0d78SAlex Zinenko // When transformations are applied, they may produce new operations from 5368ae0d78SAlex Zinenko // previously unloaded dialects. Typically, a pass would need to declare 5468ae0d78SAlex Zinenko // itself dependent on the dialects containing such new operations. To avoid 5568ae0d78SAlex Zinenko // confusion with the dialects the extension itself depends on, the Transform 5668ae0d78SAlex Zinenko // dialects differentiates between: 5768ae0d78SAlex Zinenko // - dependent dialects, which are used by the transform operations, and 5868ae0d78SAlex Zinenko // - generated dialects, which contain the entities (attributes, operations, 5968ae0d78SAlex Zinenko // types) that may be produced by applying the transformation even when 6068ae0d78SAlex Zinenko // not present in the original payload IR. 6168ae0d78SAlex Zinenko // In the following chapter, we will be add operations that generate function 6268ae0d78SAlex Zinenko // calls and structured control flow operations, so let's declare the 6368ae0d78SAlex Zinenko // corresponding dialects as generated. 6468ae0d78SAlex Zinenko declareGeneratedDialect<::mlir::scf::SCFDialect>(); 6568ae0d78SAlex Zinenko declareGeneratedDialect<::mlir::func::FuncDialect>(); 6668ae0d78SAlex Zinenko 6768ae0d78SAlex Zinenko // Finally, we register the additional transform operations with the dialect. 6868ae0d78SAlex Zinenko // List all operations generated from ODS. This call will perform additional 6968ae0d78SAlex Zinenko // checks that the operations implement the transform and memory effect 7068ae0d78SAlex Zinenko // interfaces required by the dialect interpreter and assert if they do not. 7168ae0d78SAlex Zinenko registerTransformOps< 7268ae0d78SAlex Zinenko #define GET_OP_LIST 7368ae0d78SAlex Zinenko #include "MyExtension.cpp.inc" 7468ae0d78SAlex Zinenko >(); 7568ae0d78SAlex Zinenko } 7668ae0d78SAlex Zinenko 7768ae0d78SAlex Zinenko #define GET_OP_CLASSES 7868ae0d78SAlex Zinenko #include "MyExtension.cpp.inc" 7968ae0d78SAlex Zinenko 8068ae0d78SAlex Zinenko static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { 8168ae0d78SAlex Zinenko call.setCallee(newTarget); 8268ae0d78SAlex Zinenko } 8368ae0d78SAlex Zinenko 8468ae0d78SAlex Zinenko // Implementation of our transform dialect operation. 8568ae0d78SAlex Zinenko // This operation returns a tri-state result that can be one of: 8668ae0d78SAlex Zinenko // - success when the transformation succeeded; 8768ae0d78SAlex Zinenko // - definite failure when the transformation failed in such a way that 88c63d2b2cSMatthias Springer // following transformations are impossible or undesirable, typically it could 89c63d2b2cSMatthias Springer // have left payload IR in an invalid state; it is expected that a diagnostic 90c63d2b2cSMatthias Springer // is emitted immediately before returning the definite error; 9168ae0d78SAlex Zinenko // - silenceable failure when the transformation failed but following 92c63d2b2cSMatthias Springer // transformations are still applicable, typically this means a precondition 93c63d2b2cSMatthias Springer // for the transformation is not satisfied and the payload IR has not been 94c63d2b2cSMatthias Springer // modified. The silenceable failure additionally carries a Diagnostic that 95c63d2b2cSMatthias Springer // can be emitted to the user. 9668ae0d78SAlex Zinenko ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( 97c63d2b2cSMatthias Springer // The rewriter that should be used when modifying IR. 98c63d2b2cSMatthias Springer ::mlir::transform::TransformRewriter &rewriter, 9968ae0d78SAlex Zinenko // The list of payload IR entities that will be associated with the 10068ae0d78SAlex Zinenko // transform IR values defined by this transform operation. In this case, it 10168ae0d78SAlex Zinenko // can remain empty as there are no results. 10268ae0d78SAlex Zinenko ::mlir::transform::TransformResults &results, 10368ae0d78SAlex Zinenko // The transform application state. This object can be used to query the 10468ae0d78SAlex Zinenko // current associations between transform IR values and payload IR entities. 10568ae0d78SAlex Zinenko // It can also carry additional user-defined state. 10668ae0d78SAlex Zinenko ::mlir::transform::TransformState &state) { 10768ae0d78SAlex Zinenko 10868ae0d78SAlex Zinenko // First, we need to obtain the list of payload operations that are associated 10968ae0d78SAlex Zinenko // with the operand handle. 11068ae0d78SAlex Zinenko auto payload = state.getPayloadOps(getCall()); 11168ae0d78SAlex Zinenko 11268ae0d78SAlex Zinenko // Then, we iterate over the list of operands and call the actual IR-mutating 11368ae0d78SAlex Zinenko // function. We also check the preconditions here. 11468ae0d78SAlex Zinenko for (Operation *payloadOp : payload) { 11568ae0d78SAlex Zinenko auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); 11668ae0d78SAlex Zinenko if (!call) { 11768ae0d78SAlex Zinenko DiagnosedSilenceableFailure diag = 11868ae0d78SAlex Zinenko emitSilenceableError() << "only applies to func.call payloads"; 11968ae0d78SAlex Zinenko diag.attachNote(payloadOp->getLoc()) << "offending payload"; 12068ae0d78SAlex Zinenko return diag; 12168ae0d78SAlex Zinenko } 12268ae0d78SAlex Zinenko 12368ae0d78SAlex Zinenko updateCallee(call, getNewTarget()); 12468ae0d78SAlex Zinenko } 12568ae0d78SAlex Zinenko 12668ae0d78SAlex Zinenko // If everything went well, return success. 12768ae0d78SAlex Zinenko return DiagnosedSilenceableFailure::success(); 12868ae0d78SAlex Zinenko } 12968ae0d78SAlex Zinenko 13068ae0d78SAlex Zinenko void mlir::transform::ChangeCallTargetOp::getEffects( 13168ae0d78SAlex Zinenko ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { 13268ae0d78SAlex Zinenko // Indicate that the `call` handle is only read by this operation because the 13368ae0d78SAlex Zinenko // associated operation is not erased but rather modified in-place, so the 13468ae0d78SAlex Zinenko // reference to it remains valid. 1352c1ae801Sdonald chen onlyReadsHandle(getCallMutable(), effects); 13668ae0d78SAlex Zinenko 13768ae0d78SAlex Zinenko // Indicate that the payload is modified by this operation. 13868ae0d78SAlex Zinenko modifiesPayload(effects); 13968ae0d78SAlex Zinenko } 14068ae0d78SAlex Zinenko 14168ae0d78SAlex Zinenko void registerMyExtension(::mlir::DialectRegistry ®istry) { 14268ae0d78SAlex Zinenko registry.addExtensions<MyExtension>(); 14368ae0d78SAlex Zinenko } 144