xref: /llvm-project/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp (revision 123e8c735d0765a12e65f1daefcbe23a059e26fd)
1d064c480SAlex Zinenko //===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2d064c480SAlex Zinenko //
3d064c480SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d064c480SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5d064c480SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d064c480SAlex Zinenko //
7d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
8d064c480SAlex Zinenko 
9d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10fb409a28SAlex Zinenko #include "mlir/Analysis/CallGraph.h"
110eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.h"
12bba85ebdSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h"
1399c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/Utils.h"
14*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15bba85ebdSAlex Zinenko #include "mlir/IR/DialectImplementation.h"
16fb409a28SAlex Zinenko #include "llvm/ADT/SCCIterator.h"
17d064c480SAlex Zinenko 
18d064c480SAlex Zinenko using namespace mlir;
19d064c480SAlex Zinenko 
200eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
210eb403adSAlex Zinenko 
22bba85ebdSAlex Zinenko #ifndef NDEBUG
23b586d56cSAlex Zinenko void transform::detail::checkImplementsTransformOpInterface(
24b586d56cSAlex Zinenko     StringRef name, MLIRContext *context) {
25b586d56cSAlex Zinenko   // Since the operation is being inserted into the Transform dialect and the
26b586d56cSAlex Zinenko   // dialect does not implement the interface fallback, only check for the op
27b586d56cSAlex Zinenko   // itself having the interface implementation.
28b586d56cSAlex Zinenko   RegisteredOperationName opName =
29b586d56cSAlex Zinenko       *RegisteredOperationName::lookup(name, context);
30b586d56cSAlex Zinenko   assert((opName.hasInterface<TransformOpInterface>() ||
315a10f207SMatthias Springer           opName.hasInterface<PatternDescriptorOpInterface>() ||
32bcfdb3e4SMatthias Springer           opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
33bcfdb3e4SMatthias Springer           opName.hasInterface<TypeConverterBuilderOpInterface>() ||
34b586d56cSAlex Zinenko           opName.hasTrait<OpTrait::IsTerminator>()) &&
35b586d56cSAlex Zinenko          "non-terminator ops injected into the transform dialect must "
36bcfdb3e4SMatthias Springer          "implement TransformOpInterface or PatternDescriptorOpInterface or "
37bcfdb3e4SMatthias Springer          "ConversionPatternDescriptorOpInterface");
38bcfdb3e4SMatthias Springer   if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
39bcfdb3e4SMatthias Springer       !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
40bcfdb3e4SMatthias Springer       !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
41b586d56cSAlex Zinenko     assert(opName.hasInterface<MemoryEffectOpInterface>() &&
42b586d56cSAlex Zinenko            "ops injected into the transform dialect must implement "
43b586d56cSAlex Zinenko            "MemoryEffectsOpInterface");
44b586d56cSAlex Zinenko   }
455a10f207SMatthias Springer }
46b586d56cSAlex Zinenko 
4797c05062SAlex Zinenko void transform::detail::checkImplementsTransformHandleTypeInterface(
48bba85ebdSAlex Zinenko     TypeID typeID, MLIRContext *context) {
49bba85ebdSAlex Zinenko   const auto &abstractType = AbstractType::lookup(typeID, context);
50a7026288SAlex Zinenko   assert((abstractType.hasInterface(
5197c05062SAlex Zinenko               TransformHandleTypeInterface::getInterfaceID()) ||
52ed02fa81SAlex Zinenko           abstractType.hasInterface(
53a7026288SAlex Zinenko               TransformParamTypeInterface::getInterfaceID()) ||
54a7026288SAlex Zinenko           abstractType.hasInterface(
55a7026288SAlex Zinenko               TransformValueHandleTypeInterface::getInterfaceID())) &&
56a7026288SAlex Zinenko          "expected Transform dialect type to implement one of the three "
57a7026288SAlex Zinenko          "interfaces");
58bba85ebdSAlex Zinenko }
59bba85ebdSAlex Zinenko #endif // NDEBUG
60bba85ebdSAlex Zinenko 
/// Dialect initialization hook: registers the dialect's own operations, its
/// types, and creates the dialect-owned library module.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  // Type registration is defined elsewhere (not visible in this file).
  initializeTypes();
  // Creates the module that receives symbols via loadIntoLibraryModule.
  initializeLibraryModule();
}
71a60ed954SAlex Zinenko 
72bba85ebdSAlex Zinenko Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
73bba85ebdSAlex Zinenko   StringRef keyword;
74bba85ebdSAlex Zinenko   SMLoc loc = parser.getCurrentLocation();
75bba85ebdSAlex Zinenko   if (failed(parser.parseKeyword(&keyword)))
76bba85ebdSAlex Zinenko     return nullptr;
77bba85ebdSAlex Zinenko 
78bba85ebdSAlex Zinenko   auto it = typeParsingHooks.find(keyword);
79bba85ebdSAlex Zinenko   if (it == typeParsingHooks.end()) {
80bba85ebdSAlex Zinenko     parser.emitError(loc) << "unknown type mnemonic: " << keyword;
81bba85ebdSAlex Zinenko     return nullptr;
82bba85ebdSAlex Zinenko   }
83bba85ebdSAlex Zinenko 
84bba85ebdSAlex Zinenko   return it->getValue()(parser);
85bba85ebdSAlex Zinenko }
86bba85ebdSAlex Zinenko 
87bba85ebdSAlex Zinenko void transform::TransformDialect::printType(Type type,
88bba85ebdSAlex Zinenko                                             DialectAsmPrinter &printer) const {
89bba85ebdSAlex Zinenko   auto it = typePrintingHooks.find(type.getTypeID());
90bba85ebdSAlex Zinenko   assert(it != typePrintingHooks.end() && "printing unknown type");
91bba85ebdSAlex Zinenko   it->getSecond()(type, printer);
92bba85ebdSAlex Zinenko }
93bba85ebdSAlex Zinenko 
9499c15eb4SIngo Müller LogicalResult transform::TransformDialect::loadIntoLibraryModule(
9599c15eb4SIngo Müller     ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
9699c15eb4SIngo Müller   return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
9799c15eb4SIngo Müller }
9899c15eb4SIngo Müller 
9999c15eb4SIngo Müller void transform::TransformDialect::initializeLibraryModule() {
10099c15eb4SIngo Müller   MLIRContext *context = getContext();
10199c15eb4SIngo Müller   auto loc =
10299c15eb4SIngo Müller       FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
10399c15eb4SIngo Müller   libraryModule = ModuleOp::create(loc, "__transform_library");
10499c15eb4SIngo Müller   libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
10599c15eb4SIngo Müller                                UnitAttr::get(context));
10699c15eb4SIngo Müller }
10799c15eb4SIngo Müller 
108bba85ebdSAlex Zinenko void transform::TransformDialect::reportDuplicateTypeRegistration(
109bba85ebdSAlex Zinenko     StringRef mnemonic) {
110bba85ebdSAlex Zinenko   std::string buffer;
111bba85ebdSAlex Zinenko   llvm::raw_string_ostream msg(buffer);
112b586d56cSAlex Zinenko   msg << "extensible dialect type '" << mnemonic
113bba85ebdSAlex Zinenko       << "' is already registered with a different implementation";
114bba85ebdSAlex Zinenko   llvm::report_fatal_error(StringRef(buffer));
115bba85ebdSAlex Zinenko }
116bba85ebdSAlex Zinenko 
117b586d56cSAlex Zinenko void transform::TransformDialect::reportDuplicateOpRegistration(
118b586d56cSAlex Zinenko     StringRef opName) {
119b586d56cSAlex Zinenko   std::string buffer;
120b586d56cSAlex Zinenko   llvm::raw_string_ostream msg(buffer);
121b586d56cSAlex Zinenko   msg << "extensible dialect operation '" << opName
122b586d56cSAlex Zinenko       << "' is already registered with a mismatching TypeID";
123b586d56cSAlex Zinenko   llvm::report_fatal_error(StringRef(buffer));
124b586d56cSAlex Zinenko }
125b586d56cSAlex Zinenko 
126fb409a28SAlex Zinenko LogicalResult transform::TransformDialect::verifyOperationAttribute(
127fb409a28SAlex Zinenko     Operation *op, NamedAttribute attribute) {
128fb409a28SAlex Zinenko   if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
129fb409a28SAlex Zinenko     if (!op->hasTrait<OpTrait::SymbolTable>()) {
130fb409a28SAlex Zinenko       return emitError(op->getLoc()) << attribute.getName()
131fb409a28SAlex Zinenko                                      << " attribute can only be attached to "
132fb409a28SAlex Zinenko                                         "operations with symbol tables";
133fb409a28SAlex Zinenko     }
134fb409a28SAlex Zinenko 
135fb409a28SAlex Zinenko     const mlir::CallGraph callgraph(op);
136fb409a28SAlex Zinenko     for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
137fb409a28SAlex Zinenko       if (!scc.hasCycle())
138fb409a28SAlex Zinenko         continue;
139fb409a28SAlex Zinenko 
140fb409a28SAlex Zinenko       // Need to check this here additionally because this verification may run
141fb409a28SAlex Zinenko       // before we check the nested operations.
142fb409a28SAlex Zinenko       if ((*scc->begin())->isExternal())
143fb409a28SAlex Zinenko         return op->emitOpError() << "contains a call to an external operation, "
144fb409a28SAlex Zinenko                                     "which is not allowed";
145fb409a28SAlex Zinenko 
146fb409a28SAlex Zinenko       Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
147fb409a28SAlex Zinenko       InFlightDiagnostic diag = emitError(first->getLoc())
148fb409a28SAlex Zinenko                                 << "recursion not allowed in named sequences";
149fb409a28SAlex Zinenko       for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
150fb409a28SAlex Zinenko         // Need to check this here additionally because this verification may
151fb409a28SAlex Zinenko         // run before we check the nested operations.
152fb409a28SAlex Zinenko         if ((*it)->isExternal()) {
153fb409a28SAlex Zinenko           return op->emitOpError() << "contains a call to an external "
154fb409a28SAlex Zinenko                                       "operation, which is not allowed";
155fb409a28SAlex Zinenko         }
156fb409a28SAlex Zinenko 
157fb409a28SAlex Zinenko         Operation *current = (*it)->getCallableRegion()->getParentOp();
158fb409a28SAlex Zinenko         diag.attachNote(current->getLoc()) << "operation on recursion stack";
159fb409a28SAlex Zinenko       }
160fb409a28SAlex Zinenko       return diag;
161fb409a28SAlex Zinenko     }
162fb409a28SAlex Zinenko     return success();
163fb409a28SAlex Zinenko   }
164fb409a28SAlex Zinenko   if (attribute.getName().getValue() == kTargetTagAttrName) {
165c1fa60b4STres Popp     if (!llvm::isa<StringAttr>(attribute.getValue())) {
166fb409a28SAlex Zinenko       return op->emitError()
167fb409a28SAlex Zinenko              << attribute.getName() << " attribute must be a string";
168fb409a28SAlex Zinenko     }
169fb409a28SAlex Zinenko     return success();
170fb409a28SAlex Zinenko   }
17141109341SAlex Zinenko   if (attribute.getName().getValue() == kArgConsumedAttrName ||
17241109341SAlex Zinenko       attribute.getName().getValue() == kArgReadOnlyAttrName) {
173c1fa60b4STres Popp     if (!llvm::isa<UnitAttr>(attribute.getValue())) {
17441109341SAlex Zinenko       return op->emitError()
17541109341SAlex Zinenko              << attribute.getName() << " must be a unit attribute";
17641109341SAlex Zinenko     }
17741109341SAlex Zinenko     return success();
17841109341SAlex Zinenko   }
179*5a9bdd85SOleksandr "Alex" Zinenko   if (attribute.getName().getValue() ==
180*5a9bdd85SOleksandr "Alex" Zinenko       FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
181c63d2b2cSMatthias Springer     if (!llvm::isa<UnitAttr>(attribute.getValue())) {
182c63d2b2cSMatthias Springer       return op->emitError()
183c63d2b2cSMatthias Springer              << attribute.getName() << " must be a unit attribute";
184c63d2b2cSMatthias Springer     }
185c63d2b2cSMatthias Springer     return success();
186c63d2b2cSMatthias Springer   }
187fb409a28SAlex Zinenko   return emitError(op->getLoc())
188fb409a28SAlex Zinenko          << "unknown attribute: " << attribute.getName();
189fb409a28SAlex Zinenko }
190