168ae0d78SAlex Zinenko //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 268ae0d78SAlex Zinenko // 368ae0d78SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 468ae0d78SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 568ae0d78SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 668ae0d78SAlex Zinenko // 768ae0d78SAlex Zinenko //===----------------------------------------------------------------------===// 868ae0d78SAlex Zinenko // 968ae0d78SAlex Zinenko // This file defines Transform dialect extension operations used in the 1068ae0d78SAlex Zinenko // Chapter 3 of the Transform dialect tutorial. 1168ae0d78SAlex Zinenko // 1268ae0d78SAlex Zinenko //===----------------------------------------------------------------------===// 1368ae0d78SAlex Zinenko 1468ae0d78SAlex Zinenko #include "MyExtension.h" 1568ae0d78SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h" 1668ae0d78SAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 1768ae0d78SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 185a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h" 1968ae0d78SAlex Zinenko #include "mlir/IR/DialectImplementation.h" 2068ae0d78SAlex Zinenko #include "mlir/Interfaces/CallInterfaces.h" 2168ae0d78SAlex Zinenko #include "llvm/ADT/TypeSwitch.h" 2268ae0d78SAlex Zinenko 2368ae0d78SAlex Zinenko #define GET_TYPEDEF_CLASSES 2468ae0d78SAlex Zinenko #include "MyExtensionTypes.cpp.inc" 2568ae0d78SAlex Zinenko 2668ae0d78SAlex Zinenko #define GET_OP_CLASSES 2768ae0d78SAlex Zinenko #include "MyExtension.cpp.inc" 2868ae0d78SAlex Zinenko 2968ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 3068ae0d78SAlex Zinenko // MyExtension 3168ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 3268ae0d78SAlex Zinenko 3368ae0d78SAlex Zinenko // Define a new transform dialect extension. This uses the CRTP idiom to 3468ae0d78SAlex Zinenko // identify extensions. 3568ae0d78SAlex Zinenko class MyExtension 3668ae0d78SAlex Zinenko : public ::mlir::transform::TransformDialectExtension<MyExtension> { 3768ae0d78SAlex Zinenko public: 38*84cc1865SNikhil Kalra // The TypeID of this extension. 39*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) 40*84cc1865SNikhil Kalra 4168ae0d78SAlex Zinenko // The extension must derive the base constructor. 4268ae0d78SAlex Zinenko using Base::Base; 4368ae0d78SAlex Zinenko 4468ae0d78SAlex Zinenko // This function initializes the extension, similarly to `initialize` in 4568ae0d78SAlex Zinenko // dialect definitions. List individual operations and dependent dialects 4668ae0d78SAlex Zinenko // here. 4768ae0d78SAlex Zinenko void init(); 4868ae0d78SAlex Zinenko }; 4968ae0d78SAlex Zinenko 5068ae0d78SAlex Zinenko void MyExtension::init() { 5168ae0d78SAlex Zinenko // Similarly to dialects, an extension can declare a dependent dialect. This 5268ae0d78SAlex Zinenko // dialect will be loaded along with the extension and, therefore, along with 5368ae0d78SAlex Zinenko // the Transform dialect. Only declare as dependent the dialects that contain 5468ae0d78SAlex Zinenko // the attributes or types used by transform operations. Do NOT declare as 5568ae0d78SAlex Zinenko // dependent the dialects produced during the transformation. 5668ae0d78SAlex Zinenko // declareDependentDialect<MyDialect>(); 5768ae0d78SAlex Zinenko 5868ae0d78SAlex Zinenko // When transformations are applied, they may produce new operations from 5968ae0d78SAlex Zinenko // previously unloaded dialects. Typically, a pass would need to declare 6068ae0d78SAlex Zinenko // itself dependent on the dialects containing such new operations. To avoid 6168ae0d78SAlex Zinenko // confusion with the dialects the extension itself depends on, the Transform 6268ae0d78SAlex Zinenko // dialects differentiates between: 6368ae0d78SAlex Zinenko // - dependent dialects, which are used by the transform operations, and 6468ae0d78SAlex Zinenko // - generated dialects, which contain the entities (attributes, operations, 6568ae0d78SAlex Zinenko // types) that may be produced by applying the transformation even when 6668ae0d78SAlex Zinenko // not present in the original payload IR. 6768ae0d78SAlex Zinenko // In the following chapter, we will be add operations that generate function 6868ae0d78SAlex Zinenko // calls and structured control flow operations, so let's declare the 6968ae0d78SAlex Zinenko // corresponding dialects as generated. 7068ae0d78SAlex Zinenko declareGeneratedDialect<::mlir::scf::SCFDialect>(); 7168ae0d78SAlex Zinenko declareGeneratedDialect<::mlir::func::FuncDialect>(); 7268ae0d78SAlex Zinenko 7368ae0d78SAlex Zinenko // Register the additional transform dialect types with the dialect. List all 7468ae0d78SAlex Zinenko // types generated from ODS. 7568ae0d78SAlex Zinenko registerTypes< 7668ae0d78SAlex Zinenko #define GET_TYPEDEF_LIST 7768ae0d78SAlex Zinenko #include "MyExtensionTypes.cpp.inc" 7868ae0d78SAlex Zinenko >(); 7968ae0d78SAlex Zinenko 8068ae0d78SAlex Zinenko // ODS generates these helpers for type printing and parsing, but the 8168ae0d78SAlex Zinenko // Transform dialect provides its own support for types supplied by the 8268ae0d78SAlex Zinenko // extension. Reference these functions to avoid a compiler warning. 837cdb875dSAlex Zinenko (void)&generatedTypeParser; 847cdb875dSAlex Zinenko (void)&generatedTypePrinter; 8568ae0d78SAlex Zinenko 8668ae0d78SAlex Zinenko // Finally, we register the additional transform operations with the dialect. 8768ae0d78SAlex Zinenko // List all operations generated from ODS. This call will perform additional 8868ae0d78SAlex Zinenko // checks that the operations implement the transform and memory effect 8968ae0d78SAlex Zinenko // interfaces required by the dialect interpreter and assert if they do not. 9068ae0d78SAlex Zinenko registerTransformOps< 9168ae0d78SAlex Zinenko #define GET_OP_LIST 9268ae0d78SAlex Zinenko #include "MyExtension.cpp.inc" 9368ae0d78SAlex Zinenko >(); 9468ae0d78SAlex Zinenko } 9568ae0d78SAlex Zinenko 9668ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 9768ae0d78SAlex Zinenko // ChangeCallTargetOp 9868ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 9968ae0d78SAlex Zinenko 10068ae0d78SAlex Zinenko static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { 10168ae0d78SAlex Zinenko call.setCallee(newTarget); 10268ae0d78SAlex Zinenko } 10368ae0d78SAlex Zinenko 10468ae0d78SAlex Zinenko // Implementation of our transform dialect operation. 10568ae0d78SAlex Zinenko // This operation returns a tri-state result that can be one of: 10668ae0d78SAlex Zinenko // - success when the transformation succeeded; 10768ae0d78SAlex Zinenko // - definite failure when the transformation failed in such a way that 10868ae0d78SAlex Zinenko // following 10968ae0d78SAlex Zinenko // transformations are impossible or undesirable, typically it could have left 11068ae0d78SAlex Zinenko // payload IR in an invalid state; it is expected that a diagnostic is emitted 11168ae0d78SAlex Zinenko // immediately before returning the definite error; 11268ae0d78SAlex Zinenko // - silenceable failure when the transformation failed but following 11368ae0d78SAlex Zinenko // transformations 11468ae0d78SAlex Zinenko // are still applicable, typically this means a precondition for the 11568ae0d78SAlex Zinenko // transformation is not satisfied and the payload IR has not been modified. 11668ae0d78SAlex Zinenko // The silenceable failure additionally carries a Diagnostic that can be emitted 11768ae0d78SAlex Zinenko // to the user. 11868ae0d78SAlex Zinenko ::mlir::DiagnosedSilenceableFailure 11968ae0d78SAlex Zinenko mlir::transform::ChangeCallTargetOp::applyToOne( 120c63d2b2cSMatthias Springer // The rewriter that should be used when modifying IR. 121c63d2b2cSMatthias Springer ::mlir::transform::TransformRewriter &rewriter, 12268ae0d78SAlex Zinenko // The single payload operation to which the transformation is applied. 12368ae0d78SAlex Zinenko ::mlir::func::CallOp call, 12468ae0d78SAlex Zinenko // The payload IR entities that will be appended to lists associated with 12568ae0d78SAlex Zinenko // the results of this transform operation. This list contains one entry per 12668ae0d78SAlex Zinenko // result. 12768ae0d78SAlex Zinenko ::mlir::transform::ApplyToEachResultList &results, 12868ae0d78SAlex Zinenko // The transform application state. This object can be used to query the 12968ae0d78SAlex Zinenko // current associations between transform IR values and payload IR entities. 13068ae0d78SAlex Zinenko // It can also carry additional user-defined state. 13168ae0d78SAlex Zinenko ::mlir::transform::TransformState &state) { 13268ae0d78SAlex Zinenko 13368ae0d78SAlex Zinenko // Dispatch to the actual transformation. 13468ae0d78SAlex Zinenko updateCallee(call, getNewTarget()); 13568ae0d78SAlex Zinenko 13668ae0d78SAlex Zinenko // If everything went well, return success. 13768ae0d78SAlex Zinenko return DiagnosedSilenceableFailure::success(); 13868ae0d78SAlex Zinenko } 13968ae0d78SAlex Zinenko 14068ae0d78SAlex Zinenko void mlir::transform::ChangeCallTargetOp::getEffects( 14168ae0d78SAlex Zinenko ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { 14268ae0d78SAlex Zinenko // Indicate that the `call` handle is only read by this operation because the 14368ae0d78SAlex Zinenko // associated operation is not erased but rather modified in-place, so the 14468ae0d78SAlex Zinenko // reference to it remains valid. 1452c1ae801Sdonald chen onlyReadsHandle(getCallMutable(), effects); 14668ae0d78SAlex Zinenko 14768ae0d78SAlex Zinenko // Indicate that the payload is modified by this operation. 14868ae0d78SAlex Zinenko modifiesPayload(effects); 14968ae0d78SAlex Zinenko } 15068ae0d78SAlex Zinenko 15168ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 15268ae0d78SAlex Zinenko // CallToOp 15368ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 15468ae0d78SAlex Zinenko 155c63d2b2cSMatthias Springer static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, 156c63d2b2cSMatthias Springer mlir::CallOpInterface call) { 15768ae0d78SAlex Zinenko // Construct an operation from an unregistered dialect. This is discouraged 15868ae0d78SAlex Zinenko // and is only used here for brevity of the overall example. 15968ae0d78SAlex Zinenko mlir::OperationState state(call.getLoc(), "my.mm4"); 16068ae0d78SAlex Zinenko state.types.assign(call->result_type_begin(), call->result_type_end()); 16168ae0d78SAlex Zinenko state.operands.assign(call->operand_begin(), call->operand_end()); 16268ae0d78SAlex Zinenko 163c63d2b2cSMatthias Springer mlir::Operation *replacement = rewriter.create(state); 164c63d2b2cSMatthias Springer rewriter.replaceOp(call, replacement->getResults()); 16568ae0d78SAlex Zinenko return replacement; 16668ae0d78SAlex Zinenko } 16768ae0d78SAlex Zinenko 16868ae0d78SAlex Zinenko // See above for the signature description. 16968ae0d78SAlex Zinenko mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( 170c63d2b2cSMatthias Springer mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, 171c63d2b2cSMatthias Springer mlir::transform::ApplyToEachResultList &results, 17268ae0d78SAlex Zinenko mlir::transform::TransformState &state) { 17368ae0d78SAlex Zinenko 17468ae0d78SAlex Zinenko // Dispatch to the actual transformation. 175c63d2b2cSMatthias Springer Operation *replacement = replaceCallWithOp(rewriter, call); 17668ae0d78SAlex Zinenko 17768ae0d78SAlex Zinenko // Associate the payload operation produced by the rewrite with the result 17868ae0d78SAlex Zinenko // handle of this transform operation. 17968ae0d78SAlex Zinenko results.push_back(replacement); 18068ae0d78SAlex Zinenko 18168ae0d78SAlex Zinenko // If everything went well, return success. 18268ae0d78SAlex Zinenko return DiagnosedSilenceableFailure::success(); 18368ae0d78SAlex Zinenko } 18468ae0d78SAlex Zinenko 18568ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 18668ae0d78SAlex Zinenko // CallOpInterfaceHandleType 18768ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 18868ae0d78SAlex Zinenko 18968ae0d78SAlex Zinenko // The interface declares this method to verify constraints this type has on 19068ae0d78SAlex Zinenko // payload operations. It returns the now familiar tri-state result. 19168ae0d78SAlex Zinenko mlir::DiagnosedSilenceableFailure 19268ae0d78SAlex Zinenko mlir::transform::CallOpInterfaceHandleType::checkPayload( 19368ae0d78SAlex Zinenko // Location at which diagnostics should be emitted. 19468ae0d78SAlex Zinenko mlir::Location loc, 19568ae0d78SAlex Zinenko // List of payload operations that are about to be associated with the 19668ae0d78SAlex Zinenko // handle that has this type. 19768ae0d78SAlex Zinenko llvm::ArrayRef<mlir::Operation *> payload) const { 19868ae0d78SAlex Zinenko 19968ae0d78SAlex Zinenko // All payload operations are expected to implement CallOpInterface, check 20068ae0d78SAlex Zinenko // this. 20168ae0d78SAlex Zinenko for (Operation *op : payload) { 20268ae0d78SAlex Zinenko if (llvm::isa<mlir::CallOpInterface>(op)) 20368ae0d78SAlex Zinenko continue; 20468ae0d78SAlex Zinenko 20568ae0d78SAlex Zinenko // By convention, these verifiers always emit a silenceable failure since 20668ae0d78SAlex Zinenko // they are checking a precondition. 20768ae0d78SAlex Zinenko DiagnosedSilenceableFailure diag = 20868ae0d78SAlex Zinenko emitSilenceableError(loc) 20968ae0d78SAlex Zinenko << "expected the payload operation to implement CallOpInterface"; 21068ae0d78SAlex Zinenko diag.attachNote(op->getLoc()) << "offending operation"; 21168ae0d78SAlex Zinenko return diag; 21268ae0d78SAlex Zinenko } 21368ae0d78SAlex Zinenko 21468ae0d78SAlex Zinenko // If everything is okay, return success. 21568ae0d78SAlex Zinenko return DiagnosedSilenceableFailure::success(); 21668ae0d78SAlex Zinenko } 21768ae0d78SAlex Zinenko 21868ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 21968ae0d78SAlex Zinenko // Extension registration 22068ae0d78SAlex Zinenko //===---------------------------------------------------------------------===// 22168ae0d78SAlex Zinenko 22268ae0d78SAlex Zinenko void registerMyExtension(::mlir::DialectRegistry ®istry) { 22368ae0d78SAlex Zinenko registry.addExtensions<MyExtension>(); 22468ae0d78SAlex Zinenko } 225