xref: /llvm-project/mlir/examples/transform/Ch3/lib/MyExtension.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
168ae0d78SAlex Zinenko //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
268ae0d78SAlex Zinenko //
368ae0d78SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
468ae0d78SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
568ae0d78SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
668ae0d78SAlex Zinenko //
768ae0d78SAlex Zinenko //===----------------------------------------------------------------------===//
868ae0d78SAlex Zinenko //
968ae0d78SAlex Zinenko // This file defines Transform dialect extension operations used in the
1068ae0d78SAlex Zinenko // Chapter 3 of the Transform dialect tutorial.
1168ae0d78SAlex Zinenko //
1268ae0d78SAlex Zinenko //===----------------------------------------------------------------------===//
1368ae0d78SAlex Zinenko 
1468ae0d78SAlex Zinenko #include "MyExtension.h"
1568ae0d78SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
1668ae0d78SAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1768ae0d78SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
185a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h"
1968ae0d78SAlex Zinenko #include "mlir/IR/DialectImplementation.h"
2068ae0d78SAlex Zinenko #include "mlir/Interfaces/CallInterfaces.h"
2168ae0d78SAlex Zinenko #include "llvm/ADT/TypeSwitch.h"
2268ae0d78SAlex Zinenko 
2368ae0d78SAlex Zinenko #define GET_TYPEDEF_CLASSES
2468ae0d78SAlex Zinenko #include "MyExtensionTypes.cpp.inc"
2568ae0d78SAlex Zinenko 
2668ae0d78SAlex Zinenko #define GET_OP_CLASSES
2768ae0d78SAlex Zinenko #include "MyExtension.cpp.inc"
2868ae0d78SAlex Zinenko 
2968ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
3068ae0d78SAlex Zinenko // MyExtension
3168ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
3268ae0d78SAlex Zinenko 
// Define a new transform dialect extension. This uses the CRTP idiom to
// identify extensions: the extension class passes itself as the template
// argument of its TransformDialectExtension base.
class MyExtension
    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The TypeID of this extension, defined inline in the class body with the
  // MLIR TypeID macro (which checks that it is placed inside the class).
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. Declare dependent/generated dialects and register the
  // extension's types and operations here (see the definition below).
  void init();
};
4968ae0d78SAlex Zinenko 
// Initializes the extension. It declares the dialects whose entities may be
// generated by the transformations, then registers the ODS-generated transform
// types and operations with the Transform dialect.
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Register the additional transform dialect types with the dialect. List all
  // types generated from ODS; the GET_TYPEDEF_LIST include below expands to the
  // comma-separated list of those types.
  registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
      >();

  // ODS generates these helpers for type printing and parsing, but the
  // Transform dialect provides its own support for types supplied by the
  // extension. Reference these functions to avoid a compiler warning about
  // unused functions.
  (void)&generatedTypeParser;
  (void)&generatedTypePrinter;

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS (the GET_OP_LIST include expands to
  // that list). This call will perform additional checks that the operations
  // implement the transform and memory effect interfaces required by the
  // dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
9568ae0d78SAlex Zinenko 
9668ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
9768ae0d78SAlex Zinenko // ChangeCallTargetOp
9868ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
9968ae0d78SAlex Zinenko 
10068ae0d78SAlex Zinenko static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
10168ae0d78SAlex Zinenko   call.setCallee(newTarget);
10268ae0d78SAlex Zinenko }
10368ae0d78SAlex Zinenko 
10468ae0d78SAlex Zinenko // Implementation of our transform dialect operation.
10568ae0d78SAlex Zinenko // This operation returns a tri-state result that can be one of:
10668ae0d78SAlex Zinenko // - success when the transformation succeeded;
10768ae0d78SAlex Zinenko // - definite failure when the transformation failed in such a way that
10868ae0d78SAlex Zinenko // following
10968ae0d78SAlex Zinenko //   transformations are impossible or undesirable, typically it could have left
11068ae0d78SAlex Zinenko //   payload IR in an invalid state; it is expected that a diagnostic is emitted
11168ae0d78SAlex Zinenko //   immediately before returning the definite error;
11268ae0d78SAlex Zinenko // - silenceable failure when the transformation failed but following
11368ae0d78SAlex Zinenko // transformations
11468ae0d78SAlex Zinenko //   are still applicable, typically this means a precondition for the
11568ae0d78SAlex Zinenko //   transformation is not satisfied and the payload IR has not been modified.
11668ae0d78SAlex Zinenko // The silenceable failure additionally carries a Diagnostic that can be emitted
11768ae0d78SAlex Zinenko // to the user.
11868ae0d78SAlex Zinenko ::mlir::DiagnosedSilenceableFailure
11968ae0d78SAlex Zinenko mlir::transform::ChangeCallTargetOp::applyToOne(
120c63d2b2cSMatthias Springer     // The rewriter that should be used when modifying IR.
121c63d2b2cSMatthias Springer     ::mlir::transform::TransformRewriter &rewriter,
12268ae0d78SAlex Zinenko     // The single payload operation to which the transformation is applied.
12368ae0d78SAlex Zinenko     ::mlir::func::CallOp call,
12468ae0d78SAlex Zinenko     // The payload IR entities that will be appended to lists associated with
12568ae0d78SAlex Zinenko     // the results of this transform operation. This list contains one entry per
12668ae0d78SAlex Zinenko     // result.
12768ae0d78SAlex Zinenko     ::mlir::transform::ApplyToEachResultList &results,
12868ae0d78SAlex Zinenko     // The transform application state. This object can be used to query the
12968ae0d78SAlex Zinenko     // current associations between transform IR values and payload IR entities.
13068ae0d78SAlex Zinenko     // It can also carry additional user-defined state.
13168ae0d78SAlex Zinenko     ::mlir::transform::TransformState &state) {
13268ae0d78SAlex Zinenko 
13368ae0d78SAlex Zinenko   // Dispatch to the actual transformation.
13468ae0d78SAlex Zinenko   updateCallee(call, getNewTarget());
13568ae0d78SAlex Zinenko 
13668ae0d78SAlex Zinenko   // If everything went well, return success.
13768ae0d78SAlex Zinenko   return DiagnosedSilenceableFailure::success();
13868ae0d78SAlex Zinenko }
13968ae0d78SAlex Zinenko 
14068ae0d78SAlex Zinenko void mlir::transform::ChangeCallTargetOp::getEffects(
14168ae0d78SAlex Zinenko     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
14268ae0d78SAlex Zinenko   // Indicate that the `call` handle is only read by this operation because the
14368ae0d78SAlex Zinenko   // associated operation is not erased but rather modified in-place, so the
14468ae0d78SAlex Zinenko   // reference to it remains valid.
1452c1ae801Sdonald chen   onlyReadsHandle(getCallMutable(), effects);
14668ae0d78SAlex Zinenko 
14768ae0d78SAlex Zinenko   // Indicate that the payload is modified by this operation.
14868ae0d78SAlex Zinenko   modifiesPayload(effects);
14968ae0d78SAlex Zinenko }
15068ae0d78SAlex Zinenko 
15168ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
15268ae0d78SAlex Zinenko // CallToOp
15368ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
15468ae0d78SAlex Zinenko 
155c63d2b2cSMatthias Springer static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter,
156c63d2b2cSMatthias Springer                                           mlir::CallOpInterface call) {
15768ae0d78SAlex Zinenko   // Construct an operation from an unregistered dialect. This is discouraged
15868ae0d78SAlex Zinenko   // and is only used here for brevity of the overall example.
15968ae0d78SAlex Zinenko   mlir::OperationState state(call.getLoc(), "my.mm4");
16068ae0d78SAlex Zinenko   state.types.assign(call->result_type_begin(), call->result_type_end());
16168ae0d78SAlex Zinenko   state.operands.assign(call->operand_begin(), call->operand_end());
16268ae0d78SAlex Zinenko 
163c63d2b2cSMatthias Springer   mlir::Operation *replacement = rewriter.create(state);
164c63d2b2cSMatthias Springer   rewriter.replaceOp(call, replacement->getResults());
16568ae0d78SAlex Zinenko   return replacement;
16668ae0d78SAlex Zinenko }
16768ae0d78SAlex Zinenko 
16868ae0d78SAlex Zinenko // See above for the signature description.
16968ae0d78SAlex Zinenko mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
170c63d2b2cSMatthias Springer     mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call,
171c63d2b2cSMatthias Springer     mlir::transform::ApplyToEachResultList &results,
17268ae0d78SAlex Zinenko     mlir::transform::TransformState &state) {
17368ae0d78SAlex Zinenko 
17468ae0d78SAlex Zinenko   // Dispatch to the actual transformation.
175c63d2b2cSMatthias Springer   Operation *replacement = replaceCallWithOp(rewriter, call);
17668ae0d78SAlex Zinenko 
17768ae0d78SAlex Zinenko   // Associate the payload operation produced by the rewrite with the result
17868ae0d78SAlex Zinenko   // handle of this transform operation.
17968ae0d78SAlex Zinenko   results.push_back(replacement);
18068ae0d78SAlex Zinenko 
18168ae0d78SAlex Zinenko   // If everything went well, return success.
18268ae0d78SAlex Zinenko   return DiagnosedSilenceableFailure::success();
18368ae0d78SAlex Zinenko }
18468ae0d78SAlex Zinenko 
18568ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
18668ae0d78SAlex Zinenko // CallOpInterfaceHandleType
18768ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
18868ae0d78SAlex Zinenko 
18968ae0d78SAlex Zinenko // The interface declares this method to verify constraints this type has on
19068ae0d78SAlex Zinenko // payload operations. It returns the now familiar tri-state result.
19168ae0d78SAlex Zinenko mlir::DiagnosedSilenceableFailure
19268ae0d78SAlex Zinenko mlir::transform::CallOpInterfaceHandleType::checkPayload(
19368ae0d78SAlex Zinenko     // Location at which diagnostics should be emitted.
19468ae0d78SAlex Zinenko     mlir::Location loc,
19568ae0d78SAlex Zinenko     // List of payload operations that are about to be associated with the
19668ae0d78SAlex Zinenko     // handle that has this type.
19768ae0d78SAlex Zinenko     llvm::ArrayRef<mlir::Operation *> payload) const {
19868ae0d78SAlex Zinenko 
19968ae0d78SAlex Zinenko   // All payload operations are expected to implement CallOpInterface, check
20068ae0d78SAlex Zinenko   // this.
20168ae0d78SAlex Zinenko   for (Operation *op : payload) {
20268ae0d78SAlex Zinenko     if (llvm::isa<mlir::CallOpInterface>(op))
20368ae0d78SAlex Zinenko       continue;
20468ae0d78SAlex Zinenko 
20568ae0d78SAlex Zinenko     // By convention, these verifiers always emit a silenceable failure since
20668ae0d78SAlex Zinenko     // they are checking a precondition.
20768ae0d78SAlex Zinenko     DiagnosedSilenceableFailure diag =
20868ae0d78SAlex Zinenko         emitSilenceableError(loc)
20968ae0d78SAlex Zinenko         << "expected the payload operation to implement CallOpInterface";
21068ae0d78SAlex Zinenko     diag.attachNote(op->getLoc()) << "offending operation";
21168ae0d78SAlex Zinenko     return diag;
21268ae0d78SAlex Zinenko   }
21368ae0d78SAlex Zinenko 
21468ae0d78SAlex Zinenko   // If everything is okay, return success.
21568ae0d78SAlex Zinenko   return DiagnosedSilenceableFailure::success();
21668ae0d78SAlex Zinenko }
21768ae0d78SAlex Zinenko 
21868ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
21968ae0d78SAlex Zinenko // Extension registration
22068ae0d78SAlex Zinenko //===---------------------------------------------------------------------===//
22168ae0d78SAlex Zinenko 
22268ae0d78SAlex Zinenko void registerMyExtension(::mlir::DialectRegistry &registry) {
22368ae0d78SAlex Zinenko   registry.addExtensions<MyExtension>();
22468ae0d78SAlex Zinenko }
225