xref: /llvm-project/mlir/tools/mlir-tblgen/OmpOpGen.cpp (revision e768b076e3b7ed38485a29244a0b989076e4b131)
1 //===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OmpOpGen defines OpenMP dialect operation specific generators.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/GenInfo.h"
14 
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/FormatAdapters.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/Record.h"
22 
23 using namespace llvm;
24 
/// The code block defining the base mixin class for combining clause operand
/// structures.
///
/// `detail::Clauses<Mixins...>` publicly inherits from every type in its
/// template argument list, so instantiating it with several clause operand
/// structures yields a single structure holding all of their fields. It is
/// emitted once by genClauseOps(), after the per-clause structures and before
/// the per-operation aliases that instantiate it.
static const char *const baseMixinClass = R"(
namespace detail {
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail
)";
33 
/// The code block defining operation argument structures.
///
/// This is a formatv() template:
/// - `{0}` is replaced with the operation name.
/// - `{1}` is replaced with a comma-separated list of clause operand structure
///   names, which become the mixins of `detail::Clauses`.
///
/// See genOperandsDef().
static const char *const operationArgStruct = R"(
using {0}Operands = detail::Clauses<{1}>;
)";
38 
39 /// Remove multiple optional prefixes and suffixes from \c str.
40 ///
41 /// Prefixes and suffixes are attempted to be removed once in the order they
42 /// appear in the \c prefixes and \c suffixes arguments. All prefixes are
43 /// processed before suffixes are. This means it will behave as shown in the
44 /// following example:
45 ///   - str: "PrePreNameSuf1Suf2"
46 ///   - prefixes: ["Pre"]
47 ///   - suffixes: ["Suf1", "Suf2"]
48 ///   - return: "PreNameSuf1"
49 static StringRef stripPrefixAndSuffix(StringRef str,
50                                       llvm::ArrayRef<StringRef> prefixes,
51                                       llvm::ArrayRef<StringRef> suffixes) {
52   for (StringRef prefix : prefixes)
53     if (str.starts_with(prefix))
54       str = str.drop_front(prefix.size());
55 
56   for (StringRef suffix : suffixes)
57     if (str.ends_with(suffix))
58       str = str.drop_back(suffix.size());
59 
60   return str;
61 }
62 
63 /// Obtain the name of the OpenMP clause a given record inheriting
64 /// `OpenMP_Clause` refers to.
65 ///
66 /// It supports direct and indirect `OpenMP_Clause` superclasses. Once the
67 /// `OpenMP_Clause` class the record is based on is found, the optional
68 /// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only
69 /// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse".
70 static StringRef extractOmpClauseName(const Record *clause) {
71   const Record *ompClause = clause->getRecords().getClass("OpenMP_Clause");
72   assert(ompClause && "base OpenMP records expected to be defined");
73 
74   StringRef clauseClassName;
75   SmallVector<const Record *, 1> clauseSuperClasses;
76   clause->getDirectSuperClasses(clauseSuperClasses);
77 
78   // Check if OpenMP_Clause is a direct superclass.
79   for (const Record *superClass : clauseSuperClasses) {
80     if (superClass == ompClause) {
81       clauseClassName = clause->getName();
82       break;
83     }
84   }
85 
86   // Support indirectly-inherited OpenMP_Clauses.
87   if (clauseClassName.empty()) {
88     for (auto [superClass, _] : clause->getSuperClasses()) {
89       if (superClass->isSubClassOf(ompClause)) {
90         clauseClassName = superClass->getName();
91         break;
92       }
93     }
94   }
95 
96   assert(!clauseClassName.empty() && "clause name must be found");
97 
98   // Keep only the OpenMP clause name itself for reporting purposes.
99   return stripPrefixAndSuffix(clauseClassName, /*prefixes=*/{"OpenMP_"},
100                               /*suffixes=*/{"Skip", "Clause"});
101 }
102 
103 /// Check that the given argument, identified by its name and initialization
104 /// value, is present in the \c arguments `dag`.
105 static bool verifyArgument(const DagInit *arguments, StringRef argName,
106                            const Init *argInit) {
107   auto range = zip_equal(arguments->getArgNames(), arguments->getArgs());
108   return llvm::any_of(
109       range, [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) {
110         return std::get<0>(v)->getAsUnquotedString() == argName &&
111                std::get<1>(v) == argInit;
112       });
113 }
114 
115 /// Check that the given string record value, identified by its \c opValueName,
116 /// is either undefined or empty in both the given operation and clause record
117 /// or its contents for the clause record are contained in the operation record.
118 /// Passing a non-empty \c clauseValueName enables checking values named
119 /// differently in the operation and clause records.
120 static bool verifyStringValue(const Record *op, const Record *clause,
121                               StringRef opValueName,
122                               StringRef clauseValueName = {}) {
123   auto opValue = op->getValueAsOptionalString(opValueName);
124   auto clauseValue = clause->getValueAsOptionalString(
125       clauseValueName.empty() ? opValueName : clauseValueName);
126 
127   bool opHasValue = opValue && !opValue->trim().empty();
128   bool clauseHasValue = clauseValue && !clauseValue->trim().empty();
129 
130   if (!opHasValue)
131     return !clauseHasValue;
132 
133   return !clauseHasValue || opValue->contains(clauseValue->trim());
134 }
135 
136 /// Verify that all fields of the given clause not explicitly ignored are
137 /// present in the corresponding operation field.
138 ///
139 /// Print warnings or errors where this is not the case.
140 static void verifyClause(const Record *op, const Record *clause) {
141   StringRef clauseClassName = extractOmpClauseName(clause);
142 
143   if (!clause->getValueAsBit("ignoreArgs")) {
144     const DagInit *opArguments = op->getValueAsDag("arguments");
145     const DagInit *arguments = clause->getValueAsDag("arguments");
146 
147     for (auto [name, arg] :
148          zip(arguments->getArgNames(), arguments->getArgs())) {
149       if (!verifyArgument(opArguments, name->getAsUnquotedString(), arg))
150         PrintWarning(
151             op->getLoc(),
152             "'" + clauseClassName + "' clause-defined argument '" +
153                 arg->getAsUnquotedString() + ":$" +
154                 name->getAsUnquotedString() +
155                 "' not present in operation. Consider `dag arguments = "
156                 "!con(clausesArgs, ...)` or explicitly skipping this field.");
157     }
158   }
159 
160   if (!clause->getValueAsBit("ignoreAsmFormat") &&
161       !verifyStringValue(op, clause, "assemblyFormat", "reqAssemblyFormat"))
162     PrintWarning(
163         op->getLoc(),
164         "'" + clauseClassName +
165             "' clause-defined `reqAssemblyFormat` not present in operation. "
166             "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
167             "explicitly skipping this field.");
168 
169   if (!clause->getValueAsBit("ignoreAsmFormat") &&
170       !verifyStringValue(op, clause, "assemblyFormat", "optAssemblyFormat"))
171     PrintWarning(
172         op->getLoc(),
173         "'" + clauseClassName +
174             "' clause-defined `optAssemblyFormat` not present in operation. "
175             "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
176             "explicitly skipping this field.");
177 
178   if (!clause->getValueAsBit("ignoreDesc") &&
179       !verifyStringValue(op, clause, "description"))
180     PrintError(op->getLoc(),
181                "'" + clauseClassName +
182                    "' clause-defined `description` not present in operation. "
183                    "Consider concatenating `clausesDescription` or explicitly "
184                    "skipping this field.");
185 
186   if (!clause->getValueAsBit("ignoreExtraDecl") &&
187       !verifyStringValue(op, clause, "extraClassDeclaration"))
188     PrintWarning(
189         op->getLoc(),
190         "'" + clauseClassName +
191             "' clause-defined `extraClassDeclaration` not present in "
192             "operation. Consider concatenating `clausesExtraClassDeclaration` "
193             "or explicitly skipping this field.");
194 }
195 
/// Translate the type of an OpenMP clause's argument to its corresponding
/// representation for clause operand structures.
///
/// All kinds of values are represented as `mlir::Value` fields, whereas
/// attributes are represented based on their `storageType`. Attributes nested
/// inside any array-like wrapper (i.e. when \p nest ends up greater than 0) are
/// represented as `::mlir::Attribute` instead of their `storageType`.
///
/// The record's superclasses are inspected in this order:
/// - `OptionalAttr`: recurse into its `baseAttr` without adding nesting.
/// - `TypedArrayAttrBase`: recurse into its `elementAttr`, adding one nesting
///   level.
/// - `ElementsAttrBase`: add one nesting level and return the inferred element
///   type. If no element type can be inferred, undo that nesting level, warn,
///   and fall through.
/// - Otherwise: add one nesting level for `Variadic`, then classify the
///   argument as a value (`TypeConstraint`) or an attribute (`Attr`).
///
/// \param[in] loc Location used when reporting warnings.
/// \param[in] name The name of the argument.
/// \param[in] init The `DefInit` object representing the argument.
/// \param[out] nest Number of levels of array nesting associated with the
///                  type. Must be initially set to 0.
/// \param[out] rank Rank (number of dimensions, if an array type) of the base
///                  type. Must be initially set to 1. NOTE: this function does
///                  not currently modify \p rank (see the TODO below), so
///                  callers always observe its initial value.
///
/// \return the name of the base type to represent elements of the argument
///         type.
static StringRef translateArgumentType(ArrayRef<SMLoc> loc,
                                       const StringInit *name, const Init *init,
                                       int &nest, int &rank) {
  const Record *def = cast<DefInit>(init)->getDef();

  // Collect the names of all superclasses for the checks below.
  llvm::StringSet<> superClasses;
  for (auto [sc, _] : def->getSuperClasses())
    superClasses.insert(sc->getNameInitAsString());

  // Handle wrapper-style superclasses.
  if (superClasses.contains("OptionalAttr"))
    return translateArgumentType(
        loc, name, def->getValue("baseAttr")->getValue(), nest, rank);

  // `++nest` is applied to the caller's counter before recursing, so every
  // array wrapper adds one nesting level to the final result.
  if (superClasses.contains("TypedArrayAttrBase"))
    return translateArgumentType(
        loc, name, def->getValue("elementAttr")->getValue(), ++nest, rank);

  // Handle ElementsAttrBase superclasses.
  if (superClasses.contains("ElementsAttrBase")) {
    // TODO: Obtain the rank from ranked types.
    ++nest;

    if (superClasses.contains("IntElementsAttrBase"))
      return "::llvm::APInt";
    if (superClasses.contains("FloatElementsAttr") ||
        superClasses.contains("RankedFloatElementsAttr"))
      return "::llvm::APFloat";
    // The element type is taken from the record's `returnType`, with the
    // "::llvm::ArrayRef<" ... ">" wrapper stripped.
    if (superClasses.contains("DenseArrayAttrBase"))
      return stripPrefixAndSuffix(def->getValueAsString("returnType"),
                                  {"::llvm::ArrayRef<"}, {">"});

    // Decrease the nesting depth in the case where the base type cannot be
    // inferred, so that the bare storageType is used instead of a vector.
    --nest;
    PrintWarning(
        loc,
        "could not infer array-like attribute element type for argument '" +
            name->getAsUnquotedString() + "', will use bare `storageType`");
  }

  // Handle simple attribute and value types.
  [[maybe_unused]] bool isAttr = superClasses.contains("Attr");
  bool isValue = superClasses.contains("TypeConstraint");
  if (superClasses.contains("Variadic"))
    ++nest;

  if (isValue) {
    assert(!isAttr &&
           "argument can't be simultaneously a value and an attribute");
    return "::mlir::Value";
  }

  assert(isAttr && "argument must be an attribute if it's not a value");
  // Nested attributes are stored generically; only un-nested attributes use
  // their declared `storageType`.
  return nest > 0 ? "::mlir::Attribute"
                  : def->getValueAsString("storageType").trim();
}
268 
269 /// Generate the structure that represents the arguments of the given \c clause
270 /// record of type \c OpenMP_Clause.
271 ///
272 /// It will contain a field for each argument, using the same name translated to
273 /// camel case and the corresponding base type as returned by
274 /// translateArgumentType() optionally wrapped in one or more llvm::SmallVector.
275 ///
276 /// An additional field containing a tuple of integers to hold the size of each
277 /// dimension will also be created for multi-rank types. This is not yet
278 /// supported.
279 static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
280   if (clause->isAnonymous())
281     return;
282 
283   StringRef clauseName = extractOmpClauseName(clause);
284   os << "struct " << clauseName << "ClauseOps {\n";
285 
286   const DagInit *arguments = clause->getValueAsDag("arguments");
287   for (auto [name, arg] :
288        zip_equal(arguments->getArgNames(), arguments->getArgs())) {
289     int nest = 0, rank = 1;
290     StringRef baseType =
291         translateArgumentType(clause->getLoc(), name, arg, nest, rank);
292     std::string fieldName =
293         convertToCamelFromSnakeCase(name->getAsUnquotedString(),
294                                     /*capitalizeFirst=*/false);
295 
296     os << formatv("  {0}{1}{2} {3};\n",
297                   fmt_repeat("::llvm::SmallVector<", nest), baseType,
298                   fmt_repeat(">", nest), fieldName);
299 
300     if (rank > 1) {
301       assert(nest >= 1 && "must be nested if it's a ranked type");
302       os << formatv("  {0}::std::tuple<{1}int>{2} {3}Dims;\n",
303                     fmt_repeat("::llvm::SmallVector<", nest - 1),
304                     fmt_repeat("int, ", rank - 1), fmt_repeat(">", nest - 1),
305                     fieldName);
306     }
307   }
308 
309   os << "};\n";
310 }
311 
312 /// Generate the structure that represents the clause-related arguments of the
313 /// given \c op record of type \c OpenMP_Op.
314 ///
315 /// This structure will be defined in terms of the clause operand structures
316 /// associated to the clauses of the operation.
317 static void genOperandsDef(const Record *op, raw_ostream &os) {
318   if (op->isAnonymous())
319     return;
320 
321   SmallVector<std::string> clauseNames;
322   for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
323     clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
324 
325   StringRef opName = stripPrefixAndSuffix(
326       op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
327   os << formatv(operationArgStruct, opName, join(clauseNames, ", "));
328 }
329 
330 /// Verify that all properties of `OpenMP_Clause`s of records deriving from
331 /// `OpenMP_Op`s have been inherited by the latter.
332 static bool verifyDecls(const RecordKeeper &records, raw_ostream &) {
333   for (const Record *op : records.getAllDerivedDefinitions("OpenMP_Op")) {
334     for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
335       verifyClause(op, clause);
336   }
337 
338   return false;
339 }
340 
341 /// Generate structures to represent clause-related operands, based on existing
342 /// `OpenMP_Clause` definitions and aggregate them into operation-specific
343 /// structures according to the `clauses` argument of each definition deriving
344 /// from `OpenMP_Op`.
345 static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) {
346   mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");
347   for (const Record *clause : records.getAllDerivedDefinitions("OpenMP_Clause"))
348     genClauseOpsStruct(clause, os);
349 
350   // Produce base mixin class.
351   os << baseMixinClass;
352 
353   for (const Record *op : records.getAllDerivedDefinitions("OpenMP_Op"))
354     genOperandsDef(op, os);
355 
356   return false;
357 }
358 
// Registers the generators with mlir-tblgen.
//
// The "verify-openmp-ops" generator runs verifyDecls(), which only prints
// diagnostics and never writes to its output stream.
static mlir::GenRegistration
    verifyOpenmpOps("verify-openmp-ops",
                    "Verify OpenMP operations (produce no output file)",
                    verifyDecls);

// The "gen-openmp-clause-ops" generator runs genClauseOps(), which emits the
// clause operand structures and per-operation operand aliases.
static mlir::GenRegistration
    genOpenmpClauseOps("gen-openmp-clause-ops",
                       "Generate OpenMP clause operand structures",
                       genClauseOps);
368                        genClauseOps);
369