1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // DialectGen uses the description of dialects to generate C++ definitions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DialectGenUtilities.h" 14 #include "mlir/TableGen/Class.h" 15 #include "mlir/TableGen/CodeGenHelpers.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/GenInfo.h" 18 #include "mlir/TableGen/Interfaces.h" 19 #include "mlir/TableGen/Operator.h" 20 #include "mlir/TableGen/Trait.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Signals.h" 25 #include "llvm/TableGen/Error.h" 26 #include "llvm/TableGen/Record.h" 27 #include "llvm/TableGen/TableGenBackend.h" 28 29 #define DEBUG_TYPE "mlir-tblgen-opdefgen" 30 31 using namespace mlir; 32 using namespace mlir::tblgen; 33 using llvm::Record; 34 using llvm::RecordKeeper; 35 36 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); 37 llvm::cl::opt<std::string> 38 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), 39 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); 40 41 /// Utility iterator used for filtering records for a specific dialect. 42 namespace { 43 using DialectFilterIterator = 44 llvm::filter_iterator<ArrayRef<Record *>::iterator, 45 std::function<bool(const Record *)>>; 46 } // namespace 47 48 static void populateDiscardableAttributes( 49 Dialect &dialect, const llvm::DagInit *discardableAttrDag, 50 SmallVector<std::pair<std::string, std::string>> &discardableAttributes) { 51 for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) { 52 const llvm::Init *arg = discardableAttrDag->getArg(i); 53 54 StringRef givenName = discardableAttrDag->getArgNameStr(i); 55 if (givenName.empty()) 56 PrintFatalError(dialect.getDef()->getLoc(), 57 "discardable attributes must be named"); 58 discardableAttributes.push_back( 59 {givenName.str(), arg->getAsUnquotedString()}); 60 } 61 } 62 63 /// Given a set of records for a T, filter the ones that correspond to 64 /// the given dialect. 65 template <typename T> 66 static iterator_range<DialectFilterIterator> 67 filterForDialect(ArrayRef<Record *> records, Dialect &dialect) { 68 auto filterFn = [&](const Record *record) { 69 return T(record).getDialect() == dialect; 70 }; 71 return {DialectFilterIterator(records.begin(), records.end(), filterFn), 72 DialectFilterIterator(records.end(), records.end(), filterFn)}; 73 } 74 75 std::optional<Dialect> 76 tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) { 77 if (dialects.empty()) { 78 llvm::errs() << "no dialect was found\n"; 79 return std::nullopt; 80 } 81 82 // Select the dialect to gen for. 83 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) 84 return dialects.front(); 85 86 if (selectedDialect.getNumOccurrences() == 0) { 87 llvm::errs() << "when more than 1 dialect is present, one must be selected " 88 "via '-dialect'\n"; 89 return std::nullopt; 90 } 91 92 const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) { 93 return dialect.getName() == selectedDialect; 94 }); 95 if (dialectIt == dialects.end()) { 96 llvm::errs() << "selected dialect with '-dialect' does not exist\n"; 97 return std::nullopt; 98 } 99 return *dialectIt; 100 } 101 102 //===----------------------------------------------------------------------===// 103 // GEN: Dialect declarations 104 //===----------------------------------------------------------------------===// 105 106 /// The code block for the start of a dialect class declaration. 107 /// 108 /// {0}: The name of the dialect class. 109 /// {1}: The dialect namespace. 110 /// {2}: The dialect parent class. 111 static const char *const dialectDeclBeginStr = R"( 112 class {0} : public ::mlir::{2} { 113 explicit {0}(::mlir::MLIRContext *context); 114 115 void initialize(); 116 friend class ::mlir::MLIRContext; 117 public: 118 ~{0}() override; 119 static constexpr ::llvm::StringLiteral getDialectNamespace() { 120 return ::llvm::StringLiteral("{1}"); 121 } 122 )"; 123 124 /// Registration for a single dependent dialect: to be inserted in the ctor 125 /// above for each dependent dialect. 126 const char *const dialectRegistrationTemplate = 127 "getContext()->loadDialect<{0}>();"; 128 129 /// The code block for the attribute parser/printer hooks. 130 static const char *const attrParserDecl = R"( 131 /// Parse an attribute registered to this dialect. 132 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, 133 ::mlir::Type type) const override; 134 135 /// Print an attribute registered to this dialect. 136 void printAttribute(::mlir::Attribute attr, 137 ::mlir::DialectAsmPrinter &os) const override; 138 )"; 139 140 /// The code block for the type parser/printer hooks. 141 static const char *const typeParserDecl = R"( 142 /// Parse a type registered to this dialect. 143 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; 144 145 /// Print a type registered to this dialect. 146 void printType(::mlir::Type type, 147 ::mlir::DialectAsmPrinter &os) const override; 148 )"; 149 150 /// The code block for the canonicalization pattern registration hook. 151 static const char *const canonicalizerDecl = R"( 152 /// Register canonicalization patterns. 153 void getCanonicalizationPatterns( 154 ::mlir::RewritePatternSet &results) const override; 155 )"; 156 157 /// The code block for the constant materializer hook. 158 static const char *const constantMaterializerDecl = R"( 159 /// Materialize a single constant operation from a given attribute value with 160 /// the desired resultant type. 161 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder, 162 ::mlir::Attribute value, 163 ::mlir::Type type, 164 ::mlir::Location loc) override; 165 )"; 166 167 /// The code block for the operation attribute verifier hook. 168 static const char *const opAttrVerifierDecl = R"( 169 /// Provides a hook for verifying dialect attributes attached to the given 170 /// op. 171 ::llvm::LogicalResult verifyOperationAttribute( 172 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override; 173 )"; 174 175 /// The code block for the region argument attribute verifier hook. 176 static const char *const regionArgAttrVerifierDecl = R"( 177 /// Provides a hook for verifying dialect attributes attached to the given 178 /// op's region argument. 179 ::llvm::LogicalResult verifyRegionArgAttribute( 180 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex, 181 ::mlir::NamedAttribute attribute) override; 182 )"; 183 184 /// The code block for the region result attribute verifier hook. 185 static const char *const regionResultAttrVerifierDecl = R"( 186 /// Provides a hook for verifying dialect attributes attached to the given 187 /// op's region result. 188 ::llvm::LogicalResult verifyRegionResultAttribute( 189 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex, 190 ::mlir::NamedAttribute attribute) override; 191 )"; 192 193 /// The code block for the op interface fallback hook. 194 static const char *const operationInterfaceFallbackDecl = R"( 195 /// Provides a hook for op interface. 196 void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID, 197 mlir::OperationName opName) override; 198 )"; 199 200 /// The code block for the discardable attribute helper. 201 static const char *const discardableAttrHelperDecl = R"( 202 /// Helper to manage the discardable attribute `{1}`. 203 class {0}AttrHelper {{ 204 ::mlir::StringAttr name; 205 public: 206 static constexpr ::llvm::StringLiteral getNameStr() {{ 207 return "{4}.{1}"; 208 } 209 constexpr ::mlir::StringAttr getName() {{ 210 return name; 211 } 212 213 {0}AttrHelper(::mlir::MLIRContext *ctx) 214 : name(::mlir::StringAttr::get(ctx, getNameStr())) {{} 215 216 {2} getAttr(::mlir::Operation *op) {{ 217 return op->getAttrOfType<{2}>(name); 218 } 219 void setAttr(::mlir::Operation *op, {2} val) {{ 220 op->setAttr(name, val); 221 } 222 bool isAttrPresent(::mlir::Operation *op) {{ 223 return op->hasAttrOfType<{2}>(name); 224 } 225 void removeAttr(::mlir::Operation *op) {{ 226 assert(op->hasAttrOfType<{2}>(name)); 227 op->removeAttr(name); 228 } 229 }; 230 {0}AttrHelper get{0}AttrHelper() { 231 return {3}AttrName; 232 } 233 private: 234 {0}AttrHelper {3}AttrName; 235 public: 236 )"; 237 238 /// Generate the declaration for the given dialect class. 239 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { 240 // Emit all nested namespaces. 241 { 242 NamespaceEmitter nsEmitter(os, dialect); 243 244 // Emit the start of the decl. 245 std::string cppName = dialect.getCppClassName(); 246 StringRef superClassName = 247 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; 248 os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), 249 superClassName); 250 251 // If the dialect requested the default attribute printer and parser, emit 252 // the declarations for the hooks. 253 if (dialect.useDefaultAttributePrinterParser()) 254 os << attrParserDecl; 255 // If the dialect requested the default type printer and parser, emit the 256 // delcarations for the hooks. 257 if (dialect.useDefaultTypePrinterParser()) 258 os << typeParserDecl; 259 260 // Add the decls for the various features of the dialect. 261 if (dialect.hasCanonicalizer()) 262 os << canonicalizerDecl; 263 if (dialect.hasConstantMaterializer()) 264 os << constantMaterializerDecl; 265 if (dialect.hasOperationAttrVerify()) 266 os << opAttrVerifierDecl; 267 if (dialect.hasRegionArgAttrVerify()) 268 os << regionArgAttrVerifierDecl; 269 if (dialect.hasRegionResultAttrVerify()) 270 os << regionResultAttrVerifierDecl; 271 if (dialect.hasOperationInterfaceFallback()) 272 os << operationInterfaceFallbackDecl; 273 274 const llvm::DagInit *discardableAttrDag = 275 dialect.getDiscardableAttributes(); 276 SmallVector<std::pair<std::string, std::string>> discardableAttributes; 277 populateDiscardableAttributes(dialect, discardableAttrDag, 278 discardableAttributes); 279 280 for (const auto &attrPair : discardableAttributes) { 281 std::string camelNameUpper = llvm::convertToCamelFromSnakeCase( 282 attrPair.first, /*capitalizeFirst=*/true); 283 std::string camelName = llvm::convertToCamelFromSnakeCase( 284 attrPair.first, /*capitalizeFirst=*/false); 285 os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper, 286 attrPair.first, attrPair.second, camelName, 287 dialect.getName()); 288 } 289 290 if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration()) 291 os << *extraDecl; 292 293 // End the dialect decl. 294 os << "};\n"; 295 } 296 if (!dialect.getCppNamespace().empty()) 297 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() 298 << "::" << dialect.getCppClassName() << ")\n"; 299 } 300 301 static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) { 302 emitSourceFileHeader("Dialect Declarations", os, records); 303 304 auto dialectDefs = records.getAllDerivedDefinitions("Dialect"); 305 if (dialectDefs.empty()) 306 return false; 307 308 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); 309 std::optional<Dialect> dialect = findDialectToGenerate(dialects); 310 if (!dialect) 311 return true; 312 emitDialectDecl(*dialect, os); 313 return false; 314 } 315 316 //===----------------------------------------------------------------------===// 317 // GEN: Dialect definitions 318 //===----------------------------------------------------------------------===// 319 320 /// The code block to generate a dialect constructor definition. 321 /// 322 /// {0}: The name of the dialect class. 323 /// {1}: Initialization code that is emitted in the ctor body before calling 324 /// initialize(), such as dependent dialect registration. 325 /// {2}: The dialect parent class. 326 /// {3}: Extra members to initialize 327 static const char *const dialectConstructorStr = R"( 328 {0}::{0}(::mlir::MLIRContext *context) 329 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) 330 {3} 331 {{ 332 {1} 333 initialize(); 334 } 335 )"; 336 337 /// The code block to generate a default destructor definition. 338 /// 339 /// {0}: The name of the dialect class. 340 static const char *const dialectDestructorStr = R"( 341 {0}::~{0}() = default; 342 343 )"; 344 345 static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, 346 raw_ostream &os) { 347 std::string cppClassName = dialect.getCppClassName(); 348 349 // Emit the TypeID explicit specializations to have a single symbol def. 350 if (!dialect.getCppNamespace().empty()) 351 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() 352 << "::" << cppClassName << ")\n"; 353 354 // Emit all nested namespaces. 355 NamespaceEmitter nsEmitter(os, dialect); 356 357 /// Build the list of dependent dialects. 358 std::string dependentDialectRegistrations; 359 { 360 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); 361 llvm::interleave( 362 dialect.getDependentDialects(), dialectsOs, 363 [&](StringRef dependentDialect) { 364 dialectsOs << llvm::formatv(dialectRegistrationTemplate, 365 dependentDialect); 366 }, 367 "\n "); 368 } 369 370 // Emit the constructor and destructor. 371 StringRef superClassName = 372 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; 373 374 const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); 375 SmallVector<std::pair<std::string, std::string>> discardableAttributes; 376 populateDiscardableAttributes(dialect, discardableAttrDag, 377 discardableAttributes); 378 std::string discardableAttributesInit; 379 for (const auto &attrPair : discardableAttributes) { 380 std::string camelName = llvm::convertToCamelFromSnakeCase( 381 attrPair.first, /*capitalizeFirst=*/false); 382 llvm::raw_string_ostream os(discardableAttributesInit); 383 os << ", " << camelName << "AttrName(context)"; 384 } 385 386 os << llvm::formatv(dialectConstructorStr, cppClassName, 387 dependentDialectRegistrations, superClassName, 388 discardableAttributesInit); 389 if (!dialect.hasNonDefaultDestructor()) 390 os << llvm::formatv(dialectDestructorStr, cppClassName); 391 } 392 393 static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) { 394 emitSourceFileHeader("Dialect Definitions", os, records); 395 396 auto dialectDefs = records.getAllDerivedDefinitions("Dialect"); 397 if (dialectDefs.empty()) 398 return false; 399 400 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); 401 std::optional<Dialect> dialect = findDialectToGenerate(dialects); 402 if (!dialect) 403 return true; 404 emitDialectDef(*dialect, records, os); 405 return false; 406 } 407 408 //===----------------------------------------------------------------------===// 409 // GEN: Dialect registration hooks 410 //===----------------------------------------------------------------------===// 411 412 static mlir::GenRegistration 413 genDialectDecls("gen-dialect-decls", "Generate dialect declarations", 414 [](const RecordKeeper &records, raw_ostream &os) { 415 return emitDialectDecls(records, os); 416 }); 417 418 static mlir::GenRegistration 419 genDialectDefs("gen-dialect-defs", "Generate dialect definitions", 420 [](const RecordKeeper &records, raw_ostream &os) { 421 return emitDialectDefs(records, os); 422 }); 423