xref: /llvm-project/mlir/lib/TableGen/CodeGenHelpers.cpp (revision 659192b1843c4af180700783caca4cdc7afa3eab)
1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
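// This file implements code generation helpers for the MLIR TableGen
// backends, notably the StaticVerifierFunctionEmitter, which uniques the
// type, attribute, successor, and region constraints of ops and rewrite
// patterns and emits a static C++ verification function for each of them.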
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/Path.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace llvm;
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
25 /// Generate a unique label based on the current file name to prevent name
26 /// collisions if multiple generated files are included at once.
27 static std::string getUniqueOutputLabel(const RecordKeeper &records,
28                                         StringRef tag) {
29   // Use the input file name when generating a unique name.
30   std::string inputFilename = records.getInputFilename();
31 
32   // Drop all but the base filename.
33   StringRef nameRef = sys::path::filename(inputFilename);
34   nameRef.consume_back(".td");
35 
36   // Sanitize any invalid characters.
37   std::string uniqueName(tag);
38   for (char c : nameRef) {
39     if (isAlnum(c) || c == '_')
40       uniqueName.push_back(c);
41     else
42       uniqueName.append(utohexstr((unsigned char)c));
43   }
44   return uniqueName;
45 }
46 
/// Construct an emitter that writes generated code to `os`. The label
/// computed by getUniqueOutputLabel from `records` and `tag` is stored and
/// later embedded in the function names produced by getUniqueName.
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
    raw_ostream &os, const RecordKeeper &records, StringRef tag)
    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
50 
51 void StaticVerifierFunctionEmitter::emitOpConstraints(
52     ArrayRef<const Record *> opDefs) {
53   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
54   emitTypeConstraints();
55   emitAttrConstraints();
56   emitSuccessorConstraints();
57   emitRegionConstraints();
58 }
59 
60 void StaticVerifierFunctionEmitter::emitPatternConstraints(
61     const ArrayRef<DagLeaf> constraints) {
62   collectPatternConstraints(constraints);
63   emitPatternConstraints();
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Constraint Getters
68 
69 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
70     const Constraint &constraint) const {
71   const auto *it = typeConstraints.find(constraint);
72   assert(it != typeConstraints.end() && "expected to find a type constraint");
73   return it->second;
74 }
75 
76 // Find a uniqued attribute constraint. Since not all attribute constraints can
77 // be uniqued, return std::nullopt if one was not found.
78 std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
79     const Constraint &constraint) const {
80   const auto *it = attrConstraints.find(constraint);
81   return it == attrConstraints.end() ? std::optional<StringRef>()
82                                      : StringRef(it->second);
83 }
84 
85 StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
86     const Constraint &constraint) const {
87   const auto *it = successorConstraints.find(constraint);
88   assert(it != successorConstraints.end() &&
89          "expected to find a sucessor constraint");
90   return it->second;
91 }
92 
93 StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
94     const Constraint &constraint) const {
95   const auto *it = regionConstraints.find(constraint);
96   assert(it != regionConstraints.end() &&
97          "expected to find a region constraint");
98   return it->second;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constraint Emission
103 
104 /// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
106 ///
107 /// {0}: The unique constraint name.
108 /// {1}: The constraint code.
109 /// {2}: The constraint description.
110 
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The emitted function evaluates the constraint condition ({1}) on `type`
/// and, on failure, emits an op error naming the value kind and index (both
/// supplied by the caller), the constraint summary ({2}), and the actual type.
///
/// NOTE(review): unlike attrConstraintCode, the function-body braces here are
/// not escaped as `{{`; this appears to rely on formatv treating a `{` that is
/// followed by another `{` before any `}` as a literal -- confirm before
/// editing this template.
static const char *const typeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
124 
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads are emitted under the same name: one that reports failures
/// through an `emitError` callback, and one taking the `Operation *` that
/// forwards to the first with `op->emitOpError()` as the callback. A null
/// `attr` is accepted (the check is `attr && !(condition)`). Literal braces
/// are escaped as `{{` because this text is a formatv format string.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyway.
static const char *const attrConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";
146 
/// Code for a successor constraint.
///
/// The emitted function evaluates the constraint condition ({1}) on
/// `successor` and, on failure, emits an op error of the form
/// "successor #<index> ('<name>') failed to verify constraint: <summary>".
/// The closing quote precedes the closing parenthesis; the two were
/// previously swapped, which produced `('<name>)'`.
static const char *const successorConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
159 
/// Code for a region constraint. Callers need to pass in the region's name
/// for the error message.
///
/// The emitted function evaluates the constraint condition ({1}) on `region`.
/// On failure it emits an op error containing the region index. The quoted
/// region name is included only when `regionName` is non-empty.
static const char *const regionConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
174 
/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr" -- the declaration of the checked
///      value, whose name matches the `$_self` binding used for {1}.
///
/// Instead of emitting an op error, a failed check is reported via
/// `rewriter.notifyMatchFailure`, with the caller's `failureStr` followed by
/// the constraint summary ({2}).
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
190 
191 void StaticVerifierFunctionEmitter::emitConstraints(
192     const ConstraintMap &constraints, StringRef selfName,
193     const char *const codeTemplate) {
194   FmtContext ctx;
195   ctx.addSubst("_op", "*op").withSelf(selfName);
196   for (auto &it : constraints) {
197     os << formatv(codeTemplate, it.second,
198                   tgfmt(it.first.getConditionTemplate(), &ctx),
199                   escapeString(it.first.getSummary()));
200   }
201 }
202 
203 void StaticVerifierFunctionEmitter::emitTypeConstraints() {
204   emitConstraints(typeConstraints, "type", typeConstraintCode);
205 }
206 
207 void StaticVerifierFunctionEmitter::emitAttrConstraints() {
208   emitConstraints(attrConstraints, "attr", attrConstraintCode);
209 }
210 
211 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
212   emitConstraints(successorConstraints, "successor", successorConstraintCode);
213 }
214 
215 void StaticVerifierFunctionEmitter::emitRegionConstraints() {
216   emitConstraints(regionConstraints, "region", regionConstraintCode);
217 }
218 
219 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
220   FmtContext ctx;
221   ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
222   for (auto &it : typeConstraints) {
223     os << formatv(patternAttrOrTypeConstraintCode, it.second,
224                   tgfmt(it.first.getConditionTemplate(), &ctx),
225                   escapeString(it.first.getSummary()), "Type type");
226   }
227   ctx.withSelf("attr");
228   for (auto &it : attrConstraints) {
229     os << formatv(patternAttrOrTypeConstraintCode, it.second,
230                   tgfmt(it.first.getConditionTemplate(), &ctx),
231                   escapeString(it.first.getSummary()), "Attribute attr");
232   }
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // Constraint Uniquing
237 
238 /// An attribute constraint that references anything other than itself and the
239 /// current op cannot be generically extracted into a function. Most
240 /// prohibitive are operands and results, which require calls to
241 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too
242 /// because ops use cached identifiers.
243 static bool canUniqueAttrConstraint(Attribute attr) {
244   FmtContext ctx;
245   auto test = tgfmt(attr.getConditionTemplate(),
246                     &ctx.withSelf("attr").addSubst("_op", "*op"))
247                   .str();
248   return !StringRef(test).contains("<no-subst-found>");
249 }
250 
251 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
252                                                          unsigned index) {
253   return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
254           Twine(index))
255       .str();
256 }
257 
258 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
259                                                       StringRef kind,
260                                                       Constraint constraint) {
261   auto [it, inserted] = map.try_emplace(constraint);
262   if (inserted)
263     it->second = getUniqueName(kind, map.size());
264 }
265 
266 void StaticVerifierFunctionEmitter::collectOpConstraints(
267     ArrayRef<const Record *> opDefs) {
268   const auto collectTypeConstraints = [&](Operator::const_value_range values) {
269     for (const NamedTypeConstraint &value : values)
270       if (value.hasPredicate())
271         collectConstraint(typeConstraints, "type", value.constraint);
272   };
273 
274   for (const Record *def : opDefs) {
275     Operator op(*def);
276     /// Collect type constraints.
277     collectTypeConstraints(op.getOperands());
278     collectTypeConstraints(op.getResults());
279     /// Collect attribute constraints.
280     for (const NamedAttribute &namedAttr : op.getAttributes()) {
281       if (!namedAttr.attr.getPredicate().isNull() &&
282           !namedAttr.attr.isDerivedAttr() &&
283           canUniqueAttrConstraint(namedAttr.attr))
284         collectConstraint(attrConstraints, "attr", namedAttr.attr);
285     }
286     /// Collect successor constraints.
287     for (const NamedSuccessor &successor : op.getSuccessors()) {
288       if (!successor.constraint.getPredicate().isNull()) {
289         collectConstraint(successorConstraints, "successor",
290                           successor.constraint);
291       }
292     }
293     /// Collect region constraints.
294     for (const NamedRegion &region : op.getRegions())
295       if (!region.constraint.getPredicate().isNull())
296         collectConstraint(regionConstraints, "region", region.constraint);
297   }
298 }
299 
300 void StaticVerifierFunctionEmitter::collectPatternConstraints(
301     const ArrayRef<DagLeaf> constraints) {
302   for (auto &leaf : constraints) {
303     assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
304     collectConstraint(
305         leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
306         leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
307   }
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // Public Utility Functions
312 //===----------------------------------------------------------------------===//
313 
314 std::string mlir::tblgen::escapeString(StringRef value) {
315   std::string ret;
316   raw_string_ostream os(ret);
317   os.write_escaped(value);
318   return ret;
319 }
320