xref: /llvm-project/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp (revision 123e8c735d0765a12e65f1daefcbe23a059e26fd)
1 //===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10 #include "mlir/Analysis/CallGraph.h"
11 #include "mlir/Dialect/Transform/IR/TransformOps.h"
12 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
13 #include "mlir/Dialect/Transform/IR/Utils.h"
14 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "llvm/ADT/SCCIterator.h"
17 
18 using namespace mlir;
19 
20 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
21 
22 #ifndef NDEBUG
23 void transform::detail::checkImplementsTransformOpInterface(
24     StringRef name, MLIRContext *context) {
25   // Since the operation is being inserted into the Transform dialect and the
26   // dialect does not implement the interface fallback, only check for the op
27   // itself having the interface implementation.
28   RegisteredOperationName opName =
29       *RegisteredOperationName::lookup(name, context);
30   assert((opName.hasInterface<TransformOpInterface>() ||
31           opName.hasInterface<PatternDescriptorOpInterface>() ||
32           opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
33           opName.hasInterface<TypeConverterBuilderOpInterface>() ||
34           opName.hasTrait<OpTrait::IsTerminator>()) &&
35          "non-terminator ops injected into the transform dialect must "
36          "implement TransformOpInterface or PatternDescriptorOpInterface or "
37          "ConversionPatternDescriptorOpInterface");
38   if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
39       !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
40       !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
41     assert(opName.hasInterface<MemoryEffectOpInterface>() &&
42            "ops injected into the transform dialect must implement "
43            "MemoryEffectsOpInterface");
44   }
45 }
46 
47 void transform::detail::checkImplementsTransformHandleTypeInterface(
48     TypeID typeID, MLIRContext *context) {
49   const auto &abstractType = AbstractType::lookup(typeID, context);
50   assert((abstractType.hasInterface(
51               TransformHandleTypeInterface::getInterfaceID()) ||
52           abstractType.hasInterface(
53               TransformParamTypeInterface::getInterfaceID()) ||
54           abstractType.hasInterface(
55               TransformValueHandleTypeInterface::getInterfaceID())) &&
56          "expected Transform dialect type to implement one of the three "
57          "interfaces");
58 }
59 #endif // NDEBUG
60 
/// Dialect initialization hook.
///
/// Registers the dialect's own operations, then its types (via
/// `initializeTypes`, defined elsewhere), and finally creates the library
/// module that `loadIntoLibraryModule` merges symbols into.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  // The template argument list is produced by the generated op list below.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  initializeTypes();
  initializeLibraryModule();
}
71 
72 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
73   StringRef keyword;
74   SMLoc loc = parser.getCurrentLocation();
75   if (failed(parser.parseKeyword(&keyword)))
76     return nullptr;
77 
78   auto it = typeParsingHooks.find(keyword);
79   if (it == typeParsingHooks.end()) {
80     parser.emitError(loc) << "unknown type mnemonic: " << keyword;
81     return nullptr;
82   }
83 
84   return it->getValue()(parser);
85 }
86 
87 void transform::TransformDialect::printType(Type type,
88                                             DialectAsmPrinter &printer) const {
89   auto it = typePrintingHooks.find(type.getTypeID());
90   assert(it != typePrintingHooks.end() && "printing unknown type");
91   it->getSecond()(type, printer);
92 }
93 
94 LogicalResult transform::TransformDialect::loadIntoLibraryModule(
95     ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
96   return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
97 }
98 
99 void transform::TransformDialect::initializeLibraryModule() {
100   MLIRContext *context = getContext();
101   auto loc =
102       FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
103   libraryModule = ModuleOp::create(loc, "__transform_library");
104   libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
105                                UnitAttr::get(context));
106 }
107 
108 void transform::TransformDialect::reportDuplicateTypeRegistration(
109     StringRef mnemonic) {
110   std::string buffer;
111   llvm::raw_string_ostream msg(buffer);
112   msg << "extensible dialect type '" << mnemonic
113       << "' is already registered with a different implementation";
114   llvm::report_fatal_error(StringRef(buffer));
115 }
116 
117 void transform::TransformDialect::reportDuplicateOpRegistration(
118     StringRef opName) {
119   std::string buffer;
120   llvm::raw_string_ostream msg(buffer);
121   msg << "extensible dialect operation '" << opName
122       << "' is already registered with a mismatching TypeID";
123   llvm::report_fatal_error(StringRef(buffer));
124 }
125 
126 LogicalResult transform::TransformDialect::verifyOperationAttribute(
127     Operation *op, NamedAttribute attribute) {
128   if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
129     if (!op->hasTrait<OpTrait::SymbolTable>()) {
130       return emitError(op->getLoc()) << attribute.getName()
131                                      << " attribute can only be attached to "
132                                         "operations with symbol tables";
133     }
134 
135     const mlir::CallGraph callgraph(op);
136     for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
137       if (!scc.hasCycle())
138         continue;
139 
140       // Need to check this here additionally because this verification may run
141       // before we check the nested operations.
142       if ((*scc->begin())->isExternal())
143         return op->emitOpError() << "contains a call to an external operation, "
144                                     "which is not allowed";
145 
146       Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
147       InFlightDiagnostic diag = emitError(first->getLoc())
148                                 << "recursion not allowed in named sequences";
149       for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
150         // Need to check this here additionally because this verification may
151         // run before we check the nested operations.
152         if ((*it)->isExternal()) {
153           return op->emitOpError() << "contains a call to an external "
154                                       "operation, which is not allowed";
155         }
156 
157         Operation *current = (*it)->getCallableRegion()->getParentOp();
158         diag.attachNote(current->getLoc()) << "operation on recursion stack";
159       }
160       return diag;
161     }
162     return success();
163   }
164   if (attribute.getName().getValue() == kTargetTagAttrName) {
165     if (!llvm::isa<StringAttr>(attribute.getValue())) {
166       return op->emitError()
167              << attribute.getName() << " attribute must be a string";
168     }
169     return success();
170   }
171   if (attribute.getName().getValue() == kArgConsumedAttrName ||
172       attribute.getName().getValue() == kArgReadOnlyAttrName) {
173     if (!llvm::isa<UnitAttr>(attribute.getValue())) {
174       return op->emitError()
175              << attribute.getName() << " must be a unit attribute";
176     }
177     return success();
178   }
179   if (attribute.getName().getValue() ==
180       FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
181     if (!llvm::isa<UnitAttr>(attribute.getValue())) {
182       return op->emitError()
183              << attribute.getName() << " must be a unit attribute";
184     }
185     return success();
186   }
187   return emitError(op->getLoc())
188          << "unknown attribute: " << attribute.getName();
189 }
190