xref: /llvm-project/mlir/examples/transform/Ch2/lib/MyExtension.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
168ae0d78SAlex Zinenko //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
268ae0d78SAlex Zinenko //
368ae0d78SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
468ae0d78SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
568ae0d78SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
668ae0d78SAlex Zinenko //
768ae0d78SAlex Zinenko //===----------------------------------------------------------------------===//
868ae0d78SAlex Zinenko //
968ae0d78SAlex Zinenko // This file defines Transform dialect extension operations used in the
1068ae0d78SAlex Zinenko // Chapter 2 of the Transform dialect tutorial.
1168ae0d78SAlex Zinenko //
1268ae0d78SAlex Zinenko //===----------------------------------------------------------------------===//
1368ae0d78SAlex Zinenko 
1468ae0d78SAlex Zinenko #include "MyExtension.h"
1568ae0d78SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
1668ae0d78SAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1768ae0d78SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18ec6da065SMehdi Amini #include "mlir/Dialect/Transform/IR/TransformTypes.h"
195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20ec6da065SMehdi Amini #include "mlir/IR/DialectRegistry.h"
21ec6da065SMehdi Amini #include "mlir/IR/Operation.h"
22ec6da065SMehdi Amini #include "mlir/Interfaces/SideEffectInterfaces.h"
23ec6da065SMehdi Amini #include "mlir/Support/LLVM.h"
24ec6da065SMehdi Amini #include "llvm/ADT/SmallVector.h"
25ec6da065SMehdi Amini #include "llvm/ADT/StringRef.h"
2668ae0d78SAlex Zinenko 
2768ae0d78SAlex Zinenko // Define a new transform dialect extension. This uses the CRTP idiom to
2868ae0d78SAlex Zinenko // identify extensions.
2968ae0d78SAlex Zinenko class MyExtension
3068ae0d78SAlex Zinenko     : public ::mlir::transform::TransformDialectExtension<MyExtension> {
3168ae0d78SAlex Zinenko public:
32*84cc1865SNikhil Kalra   // The TypeID of this extension.
33*84cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
34*84cc1865SNikhil Kalra 
3568ae0d78SAlex Zinenko   // The extension must derive the base constructor.
3668ae0d78SAlex Zinenko   using Base::Base;
3768ae0d78SAlex Zinenko 
3868ae0d78SAlex Zinenko   // This function initializes the extension, similarly to `initialize` in
3968ae0d78SAlex Zinenko   // dialect definitions. List individual operations and dependent dialects
4068ae0d78SAlex Zinenko   // here.
4168ae0d78SAlex Zinenko   void init();
4268ae0d78SAlex Zinenko };
4368ae0d78SAlex Zinenko 
4468ae0d78SAlex Zinenko void MyExtension::init() {
4568ae0d78SAlex Zinenko   // Similarly to dialects, an extension can declare a dependent dialect. This
4668ae0d78SAlex Zinenko   // dialect will be loaded along with the extension and, therefore, along with
4768ae0d78SAlex Zinenko   // the Transform dialect. Only declare as dependent the dialects that contain
4868ae0d78SAlex Zinenko   // the attributes or types used by transform operations. Do NOT declare as
4968ae0d78SAlex Zinenko   // dependent the dialects produced during the transformation.
5068ae0d78SAlex Zinenko   // declareDependentDialect<MyDialect>();
5168ae0d78SAlex Zinenko 
5268ae0d78SAlex Zinenko   // When transformations are applied, they may produce new operations from
5368ae0d78SAlex Zinenko   // previously unloaded dialects. Typically, a pass would need to declare
5468ae0d78SAlex Zinenko   // itself dependent on the dialects containing such new operations. To avoid
5568ae0d78SAlex Zinenko   // confusion with the dialects the extension itself depends on, the Transform
5668ae0d78SAlex Zinenko   // dialects differentiates between:
5768ae0d78SAlex Zinenko   //   - dependent dialects, which are used by the transform operations, and
5868ae0d78SAlex Zinenko   //   - generated dialects, which contain the entities (attributes, operations,
5968ae0d78SAlex Zinenko   //     types) that may be produced by applying the transformation even when
6068ae0d78SAlex Zinenko   //     not present in the original payload IR.
6168ae0d78SAlex Zinenko   // In the following chapter, we will be add operations that generate function
6268ae0d78SAlex Zinenko   // calls and structured control flow operations, so let's declare the
6368ae0d78SAlex Zinenko   // corresponding dialects as generated.
6468ae0d78SAlex Zinenko   declareGeneratedDialect<::mlir::scf::SCFDialect>();
6568ae0d78SAlex Zinenko   declareGeneratedDialect<::mlir::func::FuncDialect>();
6668ae0d78SAlex Zinenko 
6768ae0d78SAlex Zinenko   // Finally, we register the additional transform operations with the dialect.
6868ae0d78SAlex Zinenko   // List all operations generated from ODS. This call will perform additional
6968ae0d78SAlex Zinenko   // checks that the operations implement the transform and memory effect
7068ae0d78SAlex Zinenko   // interfaces required by the dialect interpreter and assert if they do not.
7168ae0d78SAlex Zinenko   registerTransformOps<
7268ae0d78SAlex Zinenko #define GET_OP_LIST
7368ae0d78SAlex Zinenko #include "MyExtension.cpp.inc"
7468ae0d78SAlex Zinenko       >();
7568ae0d78SAlex Zinenko }
7668ae0d78SAlex Zinenko 
// Pull in the ODS-generated C++ class definitions for this extension's
// operations. This is the same generated file that is included with
// GET_OP_LIST inside MyExtension::init to enumerate the operations.
#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"
7968ae0d78SAlex Zinenko 
8068ae0d78SAlex Zinenko static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
8168ae0d78SAlex Zinenko   call.setCallee(newTarget);
8268ae0d78SAlex Zinenko }
8368ae0d78SAlex Zinenko 
8468ae0d78SAlex Zinenko // Implementation of our transform dialect operation.
8568ae0d78SAlex Zinenko // This operation returns a tri-state result that can be one of:
8668ae0d78SAlex Zinenko // - success when the transformation succeeded;
8768ae0d78SAlex Zinenko // - definite failure when the transformation failed in such a way that
88c63d2b2cSMatthias Springer //   following transformations are impossible or undesirable, typically it could
89c63d2b2cSMatthias Springer //   have left payload IR in an invalid state; it is expected that a diagnostic
90c63d2b2cSMatthias Springer //   is emitted immediately before returning the definite error;
9168ae0d78SAlex Zinenko // - silenceable failure when the transformation failed but following
92c63d2b2cSMatthias Springer //   transformations are still applicable, typically this means a precondition
93c63d2b2cSMatthias Springer //   for the transformation is not satisfied and the payload IR has not been
94c63d2b2cSMatthias Springer //   modified. The silenceable failure additionally carries a Diagnostic that
95c63d2b2cSMatthias Springer //   can be emitted to the user.
9668ae0d78SAlex Zinenko ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
97c63d2b2cSMatthias Springer     // The rewriter that should be used when modifying IR.
98c63d2b2cSMatthias Springer     ::mlir::transform::TransformRewriter &rewriter,
9968ae0d78SAlex Zinenko     // The list of payload IR entities that will be associated with the
10068ae0d78SAlex Zinenko     // transform IR values defined by this transform operation. In this case, it
10168ae0d78SAlex Zinenko     // can remain empty as there are no results.
10268ae0d78SAlex Zinenko     ::mlir::transform::TransformResults &results,
10368ae0d78SAlex Zinenko     // The transform application state. This object can be used to query the
10468ae0d78SAlex Zinenko     // current associations between transform IR values and payload IR entities.
10568ae0d78SAlex Zinenko     // It can also carry additional user-defined state.
10668ae0d78SAlex Zinenko     ::mlir::transform::TransformState &state) {
10768ae0d78SAlex Zinenko 
10868ae0d78SAlex Zinenko   // First, we need to obtain the list of payload operations that are associated
10968ae0d78SAlex Zinenko   // with the operand handle.
11068ae0d78SAlex Zinenko   auto payload = state.getPayloadOps(getCall());
11168ae0d78SAlex Zinenko 
11268ae0d78SAlex Zinenko   // Then, we iterate over the list of operands and call the actual IR-mutating
11368ae0d78SAlex Zinenko   // function. We also check the preconditions here.
11468ae0d78SAlex Zinenko   for (Operation *payloadOp : payload) {
11568ae0d78SAlex Zinenko     auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
11668ae0d78SAlex Zinenko     if (!call) {
11768ae0d78SAlex Zinenko       DiagnosedSilenceableFailure diag =
11868ae0d78SAlex Zinenko           emitSilenceableError() << "only applies to func.call payloads";
11968ae0d78SAlex Zinenko       diag.attachNote(payloadOp->getLoc()) << "offending payload";
12068ae0d78SAlex Zinenko       return diag;
12168ae0d78SAlex Zinenko     }
12268ae0d78SAlex Zinenko 
12368ae0d78SAlex Zinenko     updateCallee(call, getNewTarget());
12468ae0d78SAlex Zinenko   }
12568ae0d78SAlex Zinenko 
12668ae0d78SAlex Zinenko   // If everything went well, return success.
12768ae0d78SAlex Zinenko   return DiagnosedSilenceableFailure::success();
12868ae0d78SAlex Zinenko }
12968ae0d78SAlex Zinenko 
13068ae0d78SAlex Zinenko void mlir::transform::ChangeCallTargetOp::getEffects(
13168ae0d78SAlex Zinenko     ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
13268ae0d78SAlex Zinenko   // Indicate that the `call` handle is only read by this operation because the
13368ae0d78SAlex Zinenko   // associated operation is not erased but rather modified in-place, so the
13468ae0d78SAlex Zinenko   // reference to it remains valid.
1352c1ae801Sdonald chen   onlyReadsHandle(getCallMutable(), effects);
13668ae0d78SAlex Zinenko 
13768ae0d78SAlex Zinenko   // Indicate that the payload is modified by this operation.
13868ae0d78SAlex Zinenko   modifiesPayload(effects);
13968ae0d78SAlex Zinenko }
14068ae0d78SAlex Zinenko 
14168ae0d78SAlex Zinenko void registerMyExtension(::mlir::DialectRegistry &registry) {
14268ae0d78SAlex Zinenko   registry.addExtensions<MyExtension>();
14368ae0d78SAlex Zinenko }
144