//===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OmpOpGen defines OpenMP dialect operation specific generators.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"

#include "mlir/TableGen/CodeGenHelpers.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace llvm;

/// The code block defining the base mixin class for combining clause operand
/// structures.
///
/// genClauseOps() streams this text unmodified, after the per-clause
/// structures and before the per-operation aliases that refer to
/// `detail::Clauses`.
static const char *const baseMixinClass = R"(
namespace detail {
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail
)";

/// The code block defining operation argument structures.
///
/// This is a `formatv` template filled in by genOperandsDef(): `{0}` receives
/// the operation name with its "OpenMP_" prefix and "Op" suffix stripped, and
/// `{1}` the comma-separated list of `<Clause>ClauseOps` structure names.
static const char *const operationArgStruct = R"(
using {0}Operands = detail::Clauses<{1}>;
)";

/// Remove multiple optional prefixes and suffixes from \c str.
///
/// Prefixes and suffixes are attempted to be removed once in the order they
/// appear in the \c prefixes and \c suffixes arguments. All prefixes are
/// processed before suffixes are.
This means it will behave as shown in the 44 /// following example: 45 /// - str: "PrePreNameSuf1Suf2" 46 /// - prefixes: ["Pre"] 47 /// - suffixes: ["Suf1", "Suf2"] 48 /// - return: "PreNameSuf1" 49 static StringRef stripPrefixAndSuffix(StringRef str, 50 llvm::ArrayRef<StringRef> prefixes, 51 llvm::ArrayRef<StringRef> suffixes) { 52 for (StringRef prefix : prefixes) 53 if (str.starts_with(prefix)) 54 str = str.drop_front(prefix.size()); 55 56 for (StringRef suffix : suffixes) 57 if (str.ends_with(suffix)) 58 str = str.drop_back(suffix.size()); 59 60 return str; 61 } 62 63 /// Obtain the name of the OpenMP clause a given record inheriting 64 /// `OpenMP_Clause` refers to. 65 /// 66 /// It supports direct and indirect `OpenMP_Clause` superclasses. Once the 67 /// `OpenMP_Clause` class the record is based on is found, the optional 68 /// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only 69 /// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse". 70 static StringRef extractOmpClauseName(const Record *clause) { 71 const Record *ompClause = clause->getRecords().getClass("OpenMP_Clause"); 72 assert(ompClause && "base OpenMP records expected to be defined"); 73 74 StringRef clauseClassName; 75 SmallVector<const Record *, 1> clauseSuperClasses; 76 clause->getDirectSuperClasses(clauseSuperClasses); 77 78 // Check if OpenMP_Clause is a direct superclass. 79 for (const Record *superClass : clauseSuperClasses) { 80 if (superClass == ompClause) { 81 clauseClassName = clause->getName(); 82 break; 83 } 84 } 85 86 // Support indirectly-inherited OpenMP_Clauses. 87 if (clauseClassName.empty()) { 88 for (auto [superClass, _] : clause->getSuperClasses()) { 89 if (superClass->isSubClassOf(ompClause)) { 90 clauseClassName = superClass->getName(); 91 break; 92 } 93 } 94 } 95 96 assert(!clauseClassName.empty() && "clause name must be found"); 97 98 // Keep only the OpenMP clause name itself for reporting purposes. 
99 return stripPrefixAndSuffix(clauseClassName, /*prefixes=*/{"OpenMP_"}, 100 /*suffixes=*/{"Skip", "Clause"}); 101 } 102 103 /// Check that the given argument, identified by its name and initialization 104 /// value, is present in the \c arguments `dag`. 105 static bool verifyArgument(const DagInit *arguments, StringRef argName, 106 const Init *argInit) { 107 auto range = zip_equal(arguments->getArgNames(), arguments->getArgs()); 108 return llvm::any_of( 109 range, [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) { 110 return std::get<0>(v)->getAsUnquotedString() == argName && 111 std::get<1>(v) == argInit; 112 }); 113 } 114 115 /// Check that the given string record value, identified by its \c opValueName, 116 /// is either undefined or empty in both the given operation and clause record 117 /// or its contents for the clause record are contained in the operation record. 118 /// Passing a non-empty \c clauseValueName enables checking values named 119 /// differently in the operation and clause records. 120 static bool verifyStringValue(const Record *op, const Record *clause, 121 StringRef opValueName, 122 StringRef clauseValueName = {}) { 123 auto opValue = op->getValueAsOptionalString(opValueName); 124 auto clauseValue = clause->getValueAsOptionalString( 125 clauseValueName.empty() ? opValueName : clauseValueName); 126 127 bool opHasValue = opValue && !opValue->trim().empty(); 128 bool clauseHasValue = clauseValue && !clauseValue->trim().empty(); 129 130 if (!opHasValue) 131 return !clauseHasValue; 132 133 return !clauseHasValue || opValue->contains(clauseValue->trim()); 134 } 135 136 /// Verify that all fields of the given clause not explicitly ignored are 137 /// present in the corresponding operation field. 138 /// 139 /// Print warnings or errors where this is not the case. 
140 static void verifyClause(const Record *op, const Record *clause) { 141 StringRef clauseClassName = extractOmpClauseName(clause); 142 143 if (!clause->getValueAsBit("ignoreArgs")) { 144 const DagInit *opArguments = op->getValueAsDag("arguments"); 145 const DagInit *arguments = clause->getValueAsDag("arguments"); 146 147 for (auto [name, arg] : 148 zip(arguments->getArgNames(), arguments->getArgs())) { 149 if (!verifyArgument(opArguments, name->getAsUnquotedString(), arg)) 150 PrintWarning( 151 op->getLoc(), 152 "'" + clauseClassName + "' clause-defined argument '" + 153 arg->getAsUnquotedString() + ":$" + 154 name->getAsUnquotedString() + 155 "' not present in operation. Consider `dag arguments = " 156 "!con(clausesArgs, ...)` or explicitly skipping this field."); 157 } 158 } 159 160 if (!clause->getValueAsBit("ignoreAsmFormat") && 161 !verifyStringValue(op, clause, "assemblyFormat", "reqAssemblyFormat")) 162 PrintWarning( 163 op->getLoc(), 164 "'" + clauseClassName + 165 "' clause-defined `reqAssemblyFormat` not present in operation. " 166 "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " 167 "explicitly skipping this field."); 168 169 if (!clause->getValueAsBit("ignoreAsmFormat") && 170 !verifyStringValue(op, clause, "assemblyFormat", "optAssemblyFormat")) 171 PrintWarning( 172 op->getLoc(), 173 "'" + clauseClassName + 174 "' clause-defined `optAssemblyFormat` not present in operation. " 175 "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " 176 "explicitly skipping this field."); 177 178 if (!clause->getValueAsBit("ignoreDesc") && 179 !verifyStringValue(op, clause, "description")) 180 PrintError(op->getLoc(), 181 "'" + clauseClassName + 182 "' clause-defined `description` not present in operation. 
" 183 "Consider concatenating `clausesDescription` or explicitly " 184 "skipping this field."); 185 186 if (!clause->getValueAsBit("ignoreExtraDecl") && 187 !verifyStringValue(op, clause, "extraClassDeclaration")) 188 PrintWarning( 189 op->getLoc(), 190 "'" + clauseClassName + 191 "' clause-defined `extraClassDeclaration` not present in " 192 "operation. Consider concatenating `clausesExtraClassDeclaration` " 193 "or explicitly skipping this field."); 194 } 195 196 /// Translate the type of an OpenMP clause's argument to its corresponding 197 /// representation for clause operand structures. 198 /// 199 /// All kinds of values are represented as `mlir::Value` fields, whereas 200 /// attributes are represented based on their `storageType`. 201 /// 202 /// \param[in] name The name of the argument. 203 /// \param[in] init The `DefInit` object representing the argument. 204 /// \param[out] nest Number of levels of array nesting associated with the 205 /// type. Must be initially set to 0. 206 /// \param[out] rank Rank (number of dimensions, if an array type) of the base 207 /// type. Must be initially set to 1. 208 /// 209 /// \return the name of the base type to represent elements of the argument 210 /// type. 211 static StringRef translateArgumentType(ArrayRef<SMLoc> loc, 212 const StringInit *name, const Init *init, 213 int &nest, int &rank) { 214 const Record *def = cast<DefInit>(init)->getDef(); 215 216 llvm::StringSet<> superClasses; 217 for (auto [sc, _] : def->getSuperClasses()) 218 superClasses.insert(sc->getNameInitAsString()); 219 220 // Handle wrapper-style superclasses. 221 if (superClasses.contains("OptionalAttr")) 222 return translateArgumentType( 223 loc, name, def->getValue("baseAttr")->getValue(), nest, rank); 224 225 if (superClasses.contains("TypedArrayAttrBase")) 226 return translateArgumentType( 227 loc, name, def->getValue("elementAttr")->getValue(), ++nest, rank); 228 229 // Handle ElementsAttrBase superclasses. 
230 if (superClasses.contains("ElementsAttrBase")) { 231 // TODO: Obtain the rank from ranked types. 232 ++nest; 233 234 if (superClasses.contains("IntElementsAttrBase")) 235 return "::llvm::APInt"; 236 if (superClasses.contains("FloatElementsAttr") || 237 superClasses.contains("RankedFloatElementsAttr")) 238 return "::llvm::APFloat"; 239 if (superClasses.contains("DenseArrayAttrBase")) 240 return stripPrefixAndSuffix(def->getValueAsString("returnType"), 241 {"::llvm::ArrayRef<"}, {">"}); 242 243 // Decrease the nesting depth in the case where the base type cannot be 244 // inferred, so that the bare storageType is used instead of a vector. 245 --nest; 246 PrintWarning( 247 loc, 248 "could not infer array-like attribute element type for argument '" + 249 name->getAsUnquotedString() + "', will use bare `storageType`"); 250 } 251 252 // Handle simple attribute and value types. 253 [[maybe_unused]] bool isAttr = superClasses.contains("Attr"); 254 bool isValue = superClasses.contains("TypeConstraint"); 255 if (superClasses.contains("Variadic")) 256 ++nest; 257 258 if (isValue) { 259 assert(!isAttr && 260 "argument can't be simultaneously a value and an attribute"); 261 return "::mlir::Value"; 262 } 263 264 assert(isAttr && "argument must be an attribute if it's not a value"); 265 return nest > 0 ? "::mlir::Attribute" 266 : def->getValueAsString("storageType").trim(); 267 } 268 269 /// Generate the structure that represents the arguments of the given \c clause 270 /// record of type \c OpenMP_Clause. 271 /// 272 /// It will contain a field for each argument, using the same name translated to 273 /// camel case and the corresponding base type as returned by 274 /// translateArgumentType() optionally wrapped in one or more llvm::SmallVector. 275 /// 276 /// An additional field containing a tuple of integers to hold the size of each 277 /// dimension will also be created for multi-rank types. This is not yet 278 /// supported. 
279 static void genClauseOpsStruct(const Record *clause, raw_ostream &os) { 280 if (clause->isAnonymous()) 281 return; 282 283 StringRef clauseName = extractOmpClauseName(clause); 284 os << "struct " << clauseName << "ClauseOps {\n"; 285 286 const DagInit *arguments = clause->getValueAsDag("arguments"); 287 for (auto [name, arg] : 288 zip_equal(arguments->getArgNames(), arguments->getArgs())) { 289 int nest = 0, rank = 1; 290 StringRef baseType = 291 translateArgumentType(clause->getLoc(), name, arg, nest, rank); 292 std::string fieldName = 293 convertToCamelFromSnakeCase(name->getAsUnquotedString(), 294 /*capitalizeFirst=*/false); 295 296 os << formatv(" {0}{1}{2} {3};\n", 297 fmt_repeat("::llvm::SmallVector<", nest), baseType, 298 fmt_repeat(">", nest), fieldName); 299 300 if (rank > 1) { 301 assert(nest >= 1 && "must be nested if it's a ranked type"); 302 os << formatv(" {0}::std::tuple<{1}int>{2} {3}Dims;\n", 303 fmt_repeat("::llvm::SmallVector<", nest - 1), 304 fmt_repeat("int, ", rank - 1), fmt_repeat(">", nest - 1), 305 fieldName); 306 } 307 } 308 309 os << "};\n"; 310 } 311 312 /// Generate the structure that represents the clause-related arguments of the 313 /// given \c op record of type \c OpenMP_Op. 314 /// 315 /// This structure will be defined in terms of the clause operand structures 316 /// associated to the clauses of the operation. 
317 static void genOperandsDef(const Record *op, raw_ostream &os) { 318 if (op->isAnonymous()) 319 return; 320 321 SmallVector<std::string> clauseNames; 322 for (const Record *clause : op->getValueAsListOfDefs("clauseList")) 323 clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str()); 324 325 StringRef opName = stripPrefixAndSuffix( 326 op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"}); 327 os << formatv(operationArgStruct, opName, join(clauseNames, ", ")); 328 } 329 330 /// Verify that all properties of `OpenMP_Clause`s of records deriving from 331 /// `OpenMP_Op`s have been inherited by the latter. 332 static bool verifyDecls(const RecordKeeper &records, raw_ostream &) { 333 for (const Record *op : records.getAllDerivedDefinitions("OpenMP_Op")) { 334 for (const Record *clause : op->getValueAsListOfDefs("clauseList")) 335 verifyClause(op, clause); 336 } 337 338 return false; 339 } 340 341 /// Generate structures to represent clause-related operands, based on existing 342 /// `OpenMP_Clause` definitions and aggregate them into operation-specific 343 /// structures according to the `clauses` argument of each definition deriving 344 /// from `OpenMP_Op`. 345 static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) { 346 mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp"); 347 for (const Record *clause : records.getAllDerivedDefinitions("OpenMP_Clause")) 348 genClauseOpsStruct(clause, os); 349 350 // Produce base mixin class. 351 os << baseMixinClass; 352 353 for (const Record *op : records.getAllDerivedDefinitions("OpenMP_Op")) 354 genOperandsDef(op, os); 355 356 return false; 357 } 358 359 // Registers the generator to mlir-tblgen. 
// `verify-openmp-ops`: runs verifyDecls(), which only emits diagnostics and
// writes nothing to the output stream.
static mlir::GenRegistration
    verifyOpenmpOps("verify-openmp-ops",
                    "Verify OpenMP operations (produce no output file)",
                    verifyDecls);

// `gen-openmp-clause-ops`: runs genClauseOps() to emit the clause operand
// structures and the per-operation aliases combining them.
static mlir::GenRegistration
    genOpenmpClauseOps("gen-openmp-clause-ops",
                       "Generate OpenMP clause operand structures",
                       genClauseOps);