xref: /llvm-project/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
//===- TransformDialect.td - Transform dialect definition --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
10#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
11
12include "mlir/IR/OpBase.td"
13
def Transform_Dialect : Dialect {
  let summary = "Fine-grain transformation control dialect";
  // For description, see docs/Dialects/Transform.md.

  let name = "transform";
  let cppNamespace = "::mlir::transform";

  let hasOperationAttrVerify = 1;
  let extraClassDeclaration = [{
    /// Symbol name for the default entry point "named sequence".
    constexpr const static ::llvm::StringLiteral
        kTransformEntryPointSymbolName = "__transform_main";

    /// Name of the attribute attachable to the symbol table operation
    /// containing named sequences. This is used to trigger verification.
    constexpr const static ::llvm::StringLiteral
        kWithNamedSequenceAttrName = "transform.with_named_sequence";

    /// Name of the attribute attachable to an operation so it can be
    /// identified as root by the default interpreter pass.
    constexpr const static ::llvm::StringLiteral kTargetTagAttrName =
        "transform.target_tag";

    /// Names of the attributes indicating whether an argument of an external
    /// transform dialect symbol is consumed or only read.
    constexpr const static ::llvm::StringLiteral kArgConsumedAttrName =
        "transform.consumed";
    constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
        "transform.readonly";

    /// Returns `kArgConsumedAttrName` as a StringAttr created in this
    /// dialect's context.
    StringAttr getConsumedAttrName() const {
      return StringAttr::get(getContext(), kArgConsumedAttrName);
    }
    /// Returns `kArgReadOnlyAttrName` as a StringAttr created in this
    /// dialect's context.
    StringAttr getReadOnlyAttrName() const {
      return StringAttr::get(getContext(), kArgReadOnlyAttrName);
    }

    /// Returns a read-only reference to the extra data of the kind specified
    /// as template argument. The entry is looked up with `at`, so it is
    /// expected to exist already (presumably created by an extension through
    /// `getOrCreateExtraData`).
    template <typename DataTy>
    const DataTy &getExtraData() const {
      return *static_cast<const DataTy *>(
          extraData.at(::mlir::TypeID::get<DataTy>()).get());
    }

    /// Parses a type registered by this dialect or one of its extensions.
    ::mlir::Type parseType(::mlir::DialectAsmParser & parser) const override;

    /// Prints a type registered by this dialect or one of its extensions.
    void printType(::mlir::Type type, ::mlir::DialectAsmPrinter & printer)
        const override;

    /// Parser callback for an individual type registered by this dialect or
    /// its extensions.
    using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);

    /// Printer callback for an individual type registered by this dialect or
    /// its extensions.
    using ExtensionTypePrintingHook =
        std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

    /// Loads the given module into the transform symbol library module.
    LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
                                        library);

    /// Returns the transform symbol library module available to all dialect
    /// users, or a null ModuleOp if no library module is present.
    ModuleOp getLibraryModule() const {
      if (libraryModule)
        return libraryModule.get();
      return ModuleOp();
    }

  private:
    /// Initializes the transform symbol library module. Must be called from
    /// `TransformDialect::initialize` for the library module to work.
    void initializeLibraryModule();

    /// Registers operations specified as template parameters with this
    /// dialect. Checks that they implement the required interfaces.
    template <typename... OpTys>
    void addOperationsChecked() {
      (addOperationIfNotRegistered<OpTys>(), ...);
    }
    /// Per-operation step of `addOperationsChecked`; defined out of line.
    template <typename OpTy>
    void addOperationIfNotRegistered();

    /// Reports a repeated registration error of an op with the given name.
    [[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

    /// Registers types specified as template parameters with the Transform
    /// dialect. Checks that they meet the requirements for Transform IR types.
    template <typename... TypeTys>
    void addTypesChecked() {
      (addTypeIfNotRegistered<TypeTys>(), ...);
    }
    /// Per-type step of `addTypesChecked`; defined out of line. The parameter
    /// is named `TypeTy` so that it does not shadow `::mlir::Type`.
    template <typename TypeTy>
    void addTypeIfNotRegistered();

    /// Reports a repeated registration error of a type with the given
    /// mnemonic.
    [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

    /// Registers dialect types with the context.
    void initializeTypes();

    // Give extensions access to injection functions.
    template <typename, typename...>
    friend class TransformDialectExtension;

    /// Gets a mutable reference to extra data of the kind specified as
    /// template argument. Allocates the data on the first call.
    template <typename DataTy>
    DataTy &getOrCreateExtraData();

    //===----------------------------------------------------------------===//
    // Data fields
    //===----------------------------------------------------------------===//

    /// Additional data associated with and owned by the dialect. Accessible
    /// to extensions.
    ::llvm::DenseMap<
        ::mlir::TypeID,
        std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>>
        extraData;

    /// A map from type mnemonic to its parsing function for the remainder of
    /// the syntax. The parser has access to the mnemonic, so it is used for
    /// further dispatch.
    ::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

    /// A map from type TypeID to its printing function. No need to do string
    /// lookups when the type is fully constructed.
    ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
        typePrintingHooks;

    /// Module containing symbols, e.g. named sequences, that will be resolved
    /// by the interpreter when used.
    ::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
  }];
}
155
// Common base for every operation of the Transform dialect. Operations that
// dialect extensions define may derive from it as well.
class TransformDialectOp<string mnemonic, list<Trait> traits = []> :
    Op<Transform_Dialect, mnemonic, traits> {}
160
// Marks operations that may appear at the top level of Transform IR. Such an
// operation must have one single-block region and must be usable without any
// operands; the C++ definition of the trait documents the details.
def PossibleTopLevelTransformOpTrait :
    NativeOpTrait<"PossibleTopLevelTransformOpTrait"> {
  let cppNamespace = "::mlir::transform";
}
168
169#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
170