xref: /llvm-project/mlir/lib/Tools/PDLL/ODS/Context.cpp (revision a1fe1f5f77d48b03b76884a9b9b91a6795193ac1)
1 //===- Context.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/ODS/Context.h"
10 #include "mlir/Tools/PDLL/ODS/Constraint.h"
11 #include "mlir/Tools/PDLL/ODS/Dialect.h"
12 #include "mlir/Tools/PDLL/ODS/Operation.h"
13 #include "llvm/Support/ScopedPrinter.h"
14 #include "llvm/Support/raw_ostream.h"
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace mlir::pdll::ods;
19 
20 //===----------------------------------------------------------------------===//
21 // Context
22 //===----------------------------------------------------------------------===//
23 
// The constructor and destructor are defaulted here, in the implementation
// file, rather than inline.
// NOTE(review): presumably this is so the header does not need the complete
// types held by the std::unique_ptr-valued maps -- confirm against Context.h.
24 Context::Context() = default;
25 Context::~Context() = default;
26 
27 const AttributeConstraint &
insertAttributeConstraint(StringRef name,StringRef summary,StringRef cppClass)28 Context::insertAttributeConstraint(StringRef name, StringRef summary,
29                                    StringRef cppClass) {
30   std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
31   if (!constraint) {
32     constraint.reset(new AttributeConstraint(name, summary, cppClass));
33   } else {
34     assert(constraint->getCppClass() == cppClass &&
35            constraint->getSummary() == summary &&
36            "constraint with the same name was already registered with a "
37            "different class");
38   }
39   return *constraint;
40 }
41 
insertTypeConstraint(StringRef name,StringRef summary,StringRef cppClass)42 const TypeConstraint &Context::insertTypeConstraint(StringRef name,
43                                                     StringRef summary,
44                                                     StringRef cppClass) {
45   std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
46   if (!constraint)
47     constraint.reset(new TypeConstraint(name, summary, cppClass));
48   return *constraint;
49 }
50 
insertDialect(StringRef name)51 Dialect &Context::insertDialect(StringRef name) {
52   std::unique_ptr<Dialect> &dialect = dialects[name];
53   if (!dialect)
54     dialect.reset(new Dialect(name));
55   return *dialect;
56 }
57 
lookupDialect(StringRef name) const58 const Dialect *Context::lookupDialect(StringRef name) const {
59   auto it = dialects.find(name);
60   return it == dialects.end() ? nullptr : &*it->second;
61 }
62 
63 std::pair<Operation *, bool>
insertOperation(StringRef name,StringRef summary,StringRef desc,StringRef nativeClassName,bool supportsResultTypeInferrence,SMLoc loc)64 Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
65                          StringRef nativeClassName,
66                          bool supportsResultTypeInferrence, SMLoc loc) {
67   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
68   return insertDialect(dialectAndName.first)
69       .insertOperation(name, summary, desc, nativeClassName,
70                        supportsResultTypeInferrence, loc);
71 }
72 
lookupOperation(StringRef name) const73 const Operation *Context::lookupOperation(StringRef name) const {
74   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
75   if (const Dialect *dialect = lookupDialect(dialectAndName.first))
76     return dialect->lookupOperation(name);
77   return nullptr;
78 }
79 
80 template <typename T>
sortMapByName(const llvm::StringMap<std::unique_ptr<T>> & map)81 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
82   SmallVector<T *> storage;
83   for (auto &entry : map)
84     storage.push_back(entry.second.get());
85   llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
86     return lhs->getName() < rhs->getName();
87   });
88   return storage;
89 }
90 
print(raw_ostream & os) const91 void Context::print(raw_ostream &os) const {
92   auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
93     switch (kind) {
94     case VariableLengthKind::Optional:
95       os << "Optional<" << cst << ">";
96       break;
97     case VariableLengthKind::Single:
98       os << cst;
99       break;
100     case VariableLengthKind::Variadic:
101       os << "Variadic<" << cst << ">";
102       break;
103     }
104   };
105 
106   llvm::ScopedPrinter printer(os);
107   llvm::DictScope odsScope(printer, "ODSContext");
108   for (const Dialect *dialect : sortMapByName(dialects)) {
109     printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
110     printer.indent();
111 
112     for (const Operation *op : sortMapByName(dialect->getOperations())) {
113       printer.startLine() << "Operation `" << op->getName() << "` {\n";
114       printer.indent();
115 
116       // Attributes.
117       ArrayRef<Attribute> attributes = op->getAttributes();
118       if (!attributes.empty()) {
119         printer.startLine() << "Attributes { ";
120         llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
121           os << attr.getName() << " : ";
122 
123           auto kind = attr.isOptional() ? VariableLengthKind::Optional
124                                         : VariableLengthKind::Single;
125           printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
126         });
127         os << " }\n";
128       }
129 
130       // Operands.
131       ArrayRef<OperandOrResult> operands = op->getOperands();
132       if (!operands.empty()) {
133         printer.startLine() << "Operands { ";
134         llvm::interleaveComma(
135             operands, os, [&](const OperandOrResult &operand) {
136               os << operand.getName() << " : ";
137               printVariableLengthCst(operand.getConstraint().getDemangledName(),
138                                      operand.getVariableLengthKind());
139             });
140         os << " }\n";
141       }
142 
143       // Results.
144       ArrayRef<OperandOrResult> results = op->getResults();
145       if (!results.empty()) {
146         printer.startLine() << "Results { ";
147         llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
148           os << result.getName() << " : ";
149           printVariableLengthCst(result.getConstraint().getDemangledName(),
150                                  result.getVariableLengthKind());
151         });
152         os << " }\n";
153       }
154 
155       printer.objectEnd();
156     }
157     printer.objectEnd();
158   }
159   for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
160     printer.startLine() << "AttributeConstraint `" << cst->getDemangledName()
161                         << "` {\n";
162     printer.indent();
163 
164     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
165     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
166     printer.objectEnd();
167   }
168   for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
169     printer.startLine() << "TypeConstraint `" << cst->getDemangledName()
170                         << "` {\n";
171     printer.indent();
172 
173     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
174     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
175     printer.objectEnd();
176   }
177   printer.objectEnd();
178 }
179