1 //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpDefinitionsGen uses the description of operations to generate IRDL 10 // definitions for ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/IRDL/IR/IRDL.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/TableGen/AttrOrTypeDef.h" 22 #include "mlir/TableGen/GenInfo.h" 23 #include "mlir/TableGen/GenNameParser.h" 24 #include "mlir/TableGen/Interfaces.h" 25 #include "mlir/TableGen/Operator.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/Support/CommandLine.h" 28 #include "llvm/Support/InitLLVM.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include "llvm/TableGen/Main.h" 31 #include "llvm/TableGen/Record.h" 32 #include "llvm/TableGen/TableGenBackend.h" 33 34 using namespace llvm; 35 using namespace mlir; 36 using tblgen::NamedTypeConstraint; 37 38 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect"); 39 llvm::cl::opt<std::string> 40 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), 41 llvm::cl::cat(dialectGenCat), llvm::cl::Required); 42 43 Value createPredicate(OpBuilder &builder, tblgen::Pred pred) { 44 MLIRContext *ctx = builder.getContext(); 45 46 if (pred.isCombined()) { 47 auto combiner = pred.getDef().getValueAsDef("kind")->getName(); 48 if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") { 49 std::vector<Value> constraints; 50 for (auto *child : pred.getDef().getValueAsListOfDefs("children")) { 51 constraints.push_back(createPredicate(builder, tblgen::Pred(child))); 52 } 53 if (combiner == "PredCombinerAnd") { 54 auto op = 55 builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); 56 return op.getOutput(); 57 } 58 auto op = 59 builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); 60 return op.getOutput(); 61 } 62 } 63 64 std::string condition = pred.getCondition(); 65 // Build a CPredOp to match the C constraint built. 66 irdl::CPredOp op = builder.create<irdl::CPredOp>( 67 UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); 68 return op; 69 } 70 71 Value typeToConstraint(OpBuilder &builder, Type type) { 72 MLIRContext *ctx = builder.getContext(); 73 auto op = 74 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type)); 75 return op.getOutput(); 76 } 77 78 Value baseToConstraint(OpBuilder &builder, StringRef baseClass) { 79 MLIRContext *ctx = builder.getContext(); 80 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), 81 StringAttr::get(ctx, baseClass)); 82 return op.getOutput(); 83 } 84 85 std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) { 86 if (predRec.isSubClassOf("I")) { 87 auto width = predRec.getValueAsInt("bitwidth"); 88 return IntegerType::get(ctx, width, IntegerType::Signless); 89 } 90 91 if (predRec.isSubClassOf("SI")) { 92 auto width = predRec.getValueAsInt("bitwidth"); 93 return IntegerType::get(ctx, width, IntegerType::Signed); 94 } 95 96 if (predRec.isSubClassOf("UI")) { 97 auto width = predRec.getValueAsInt("bitwidth"); 98 return IntegerType::get(ctx, width, IntegerType::Unsigned); 99 } 100 101 // Index type 102 if (predRec.getName() == "Index") { 103 return IndexType::get(ctx); 104 } 105 106 // Float types 107 if (predRec.isSubClassOf("F")) { 108 auto width = predRec.getValueAsInt("bitwidth"); 109 switch (width) { 110 case 16: 111 return Float16Type::get(ctx); 112 case 32: 113 return Float32Type::get(ctx); 114 case 64: 115 return Float64Type::get(ctx); 116 case 80: 117 return Float80Type::get(ctx); 118 case 128: 119 return Float128Type::get(ctx); 120 } 121 } 122 123 if (predRec.getName() == "NoneType") { 124 return NoneType::get(ctx); 125 } 126 127 if (predRec.getName() == "BF16") { 128 return BFloat16Type::get(ctx); 129 } 130 131 if (predRec.getName() == "TF32") { 132 return FloatTF32Type::get(ctx); 133 } 134 135 if (predRec.getName() == "F8E4M3FN") { 136 return Float8E4M3FNType::get(ctx); 137 } 138 139 if (predRec.getName() == "F8E5M2") { 140 return Float8E5M2Type::get(ctx); 141 } 142 143 if (predRec.getName() == "F8E4M3") { 144 return Float8E4M3Type::get(ctx); 145 } 146 147 if (predRec.getName() == "F8E4M3FNUZ") { 148 return Float8E4M3FNUZType::get(ctx); 149 } 150 151 if (predRec.getName() == "F8E4M3B11FNUZ") { 152 return Float8E4M3B11FNUZType::get(ctx); 153 } 154 155 if (predRec.getName() == "F8E5M2FNUZ") { 156 return Float8E5M2FNUZType::get(ctx); 157 } 158 159 if (predRec.getName() == "F8E3M4") { 160 return Float8E3M4Type::get(ctx); 161 } 162 163 if (predRec.isSubClassOf("Complex")) { 164 const Record *elementRec = predRec.getValueAsDef("elementType"); 165 auto elementType = recordToType(ctx, *elementRec); 166 if (elementType.has_value()) { 167 return ComplexType::get(elementType.value()); 168 } 169 } 170 171 return std::nullopt; 172 } 173 174 Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { 175 MLIRContext *ctx = builder.getContext(); 176 const Record &predRec = constraint.getDef(); 177 178 if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional")) 179 return createTypeConstraint(builder, predRec.getValueAsDef("baseType")); 180 181 if (predRec.getName() == "AnyType") { 182 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); 183 return op.getOutput(); 184 } 185 186 if (predRec.isSubClassOf("TypeDef")) { 187 auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name"); 188 if (dialect == selectedDialect) { 189 std::string combined = ("!" + predRec.getValueAsString("mnemonic")).str(); 190 SmallVector<FlatSymbolRefAttr> nested = { 191 SymbolRefAttr::get(ctx, combined)}; 192 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); 193 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); 194 return op.getOutput(); 195 } 196 std::string typeName = ("!" + predRec.getValueAsString("typeName")).str(); 197 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), 198 StringAttr::get(ctx, typeName)); 199 return op.getOutput(); 200 } 201 202 if (predRec.isSubClassOf("AnyTypeOf")) { 203 std::vector<Value> constraints; 204 for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) { 205 constraints.push_back( 206 createTypeConstraint(builder, tblgen::Constraint(child))); 207 } 208 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); 209 return op.getOutput(); 210 } 211 212 if (predRec.isSubClassOf("AllOfType")) { 213 std::vector<Value> constraints; 214 for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) { 215 constraints.push_back( 216 createTypeConstraint(builder, tblgen::Constraint(child))); 217 } 218 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); 219 return op.getOutput(); 220 } 221 222 // Integer types 223 if (predRec.getName() == "AnyInteger") { 224 auto op = builder.create<irdl::BaseOp>( 225 UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer")); 226 return op.getOutput(); 227 } 228 229 if (predRec.isSubClassOf("AnyI")) { 230 auto width = predRec.getValueAsInt("bitwidth"); 231 std::vector<Value> types = { 232 typeToConstraint(builder, 233 IntegerType::get(ctx, width, IntegerType::Signless)), 234 typeToConstraint(builder, 235 IntegerType::get(ctx, width, IntegerType::Signed)), 236 typeToConstraint(builder, 237 IntegerType::get(ctx, width, IntegerType::Unsigned))}; 238 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types); 239 return op.getOutput(); 240 } 241 242 auto type = recordToType(ctx, predRec); 243 244 if (type.has_value()) { 245 return typeToConstraint(builder, type.value()); 246 } 247 248 // Confined type 249 if (predRec.isSubClassOf("ConfinedType")) { 250 std::vector<Value> constraints; 251 constraints.push_back(createTypeConstraint( 252 builder, tblgen::Constraint(predRec.getValueAsDef("baseType")))); 253 for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) { 254 constraints.push_back(createPredicate(builder, tblgen::Pred(child))); 255 } 256 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); 257 return op.getOutput(); 258 } 259 260 return createPredicate(builder, constraint.getPredicate()); 261 } 262 263 Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { 264 MLIRContext *ctx = builder.getContext(); 265 const Record &predRec = constraint.getDef(); 266 267 if (predRec.isSubClassOf("DefaultValuedAttr") || 268 predRec.isSubClassOf("DefaultValuedOptionalAttr") || 269 predRec.isSubClassOf("OptionalAttr")) { 270 return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr")); 271 } 272 273 if (predRec.isSubClassOf("ConfinedAttr")) { 274 std::vector<Value> constraints; 275 constraints.push_back(createAttrConstraint( 276 builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr")))); 277 for (const Record *child : 278 predRec.getValueAsListOfDefs("attrConstraints")) { 279 constraints.push_back(createPredicate( 280 builder, tblgen::Pred(child->getValueAsDef("predicate")))); 281 } 282 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); 283 return op.getOutput(); 284 } 285 286 if (predRec.isSubClassOf("AnyAttrOf")) { 287 std::vector<Value> constraints; 288 for (const Record *child : 289 predRec.getValueAsListOfDefs("allowedAttributes")) { 290 constraints.push_back( 291 createAttrConstraint(builder, tblgen::Constraint(child))); 292 } 293 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); 294 return op.getOutput(); 295 } 296 297 if (predRec.getName() == "AnyAttr") { 298 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); 299 return op.getOutput(); 300 } 301 302 if (predRec.isSubClassOf("AnyIntegerAttrBase") || 303 predRec.isSubClassOf("SignlessIntegerAttrBase") || 304 predRec.isSubClassOf("SignedIntegerAttrBase") || 305 predRec.isSubClassOf("UnsignedIntegerAttrBase") || 306 predRec.isSubClassOf("BoolAttr")) { 307 return baseToConstraint(builder, "!builtin.integer"); 308 } 309 310 if (predRec.isSubClassOf("FloatAttrBase")) { 311 return baseToConstraint(builder, "!builtin.float"); 312 } 313 314 if (predRec.isSubClassOf("StringBasedAttr")) { 315 return baseToConstraint(builder, "!builtin.string"); 316 } 317 318 if (predRec.getName() == "UnitAttr") { 319 auto op = 320 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx)); 321 return op.getOutput(); 322 } 323 324 if (predRec.isSubClassOf("AttrDef")) { 325 auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name"); 326 if (dialect == selectedDialect) { 327 std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str(); 328 SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined) 329 330 }; 331 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); 332 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); 333 return op.getOutput(); 334 } 335 std::string typeName = ("#" + predRec.getValueAsString("attrName")).str(); 336 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), 337 StringAttr::get(ctx, typeName)); 338 return op.getOutput(); 339 } 340 341 return createPredicate(builder, constraint.getPredicate()); 342 } 343 344 Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) { 345 MLIRContext *ctx = builder.getContext(); 346 const Record &predRec = constraint.getDef(); 347 348 if (predRec.getName() == "AnyRegion") { 349 ValueRange entryBlockArgs = {}; 350 auto op = 351 builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs); 352 return op.getResult(); 353 } 354 355 if (predRec.isSubClassOf("SizedRegion")) { 356 ValueRange entryBlockArgs = {}; 357 auto ty = IntegerType::get(ctx, 32); 358 auto op = builder.create<irdl::RegionOp>( 359 UnknownLoc::get(ctx), entryBlockArgs, 360 IntegerAttr::get(ty, predRec.getValueAsInt("blocks"))); 361 return op.getResult(); 362 } 363 364 return createPredicate(builder, constraint.getPredicate()); 365 } 366 367 /// Returns the name of the operation without the dialect prefix. 368 static StringRef getOperatorName(tblgen::Operator &tblgenOp) { 369 StringRef opName = tblgenOp.getDef().getValueAsString("opName"); 370 return opName; 371 } 372 373 /// Returns the name of the type without the dialect prefix. 374 static StringRef getTypeName(tblgen::TypeDef &tblgenType) { 375 StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic"); 376 return opName; 377 } 378 379 /// Returns the name of the attr without the dialect prefix. 380 static StringRef getAttrName(tblgen::AttrDef &tblgenType) { 381 StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic"); 382 return opName; 383 } 384 385 /// Extract an operation to IRDL. 386 irdl::OperationOp createIRDLOperation(OpBuilder &builder, 387 tblgen::Operator &tblgenOp) { 388 MLIRContext *ctx = builder.getContext(); 389 StringRef opName = getOperatorName(tblgenOp); 390 391 irdl::OperationOp op = builder.create<irdl::OperationOp>( 392 UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); 393 394 // Add the block in the region. 395 Block &opBlock = op.getBody().emplaceBlock(); 396 OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock); 397 398 SmallDenseSet<StringRef> usedNames; 399 for (auto &namedCons : tblgenOp.getOperands()) 400 usedNames.insert(namedCons.name); 401 for (auto &namedCons : tblgenOp.getResults()) 402 usedNames.insert(namedCons.name); 403 for (auto &namedReg : tblgenOp.getRegions()) 404 usedNames.insert(namedReg.name); 405 406 size_t generateCounter = 0; 407 auto generateName = [&](StringRef prefix) -> StringAttr { 408 SmallString<16> candidate; 409 do { 410 candidate.clear(); 411 raw_svector_ostream candidateStream(candidate); 412 candidateStream << prefix << generateCounter; 413 generateCounter++; 414 } while (usedNames.contains(candidate)); 415 return StringAttr::get(ctx, candidate); 416 }; 417 auto normalizeName = [&](StringRef name) -> StringAttr { 418 if (name == "") 419 return generateName("unnamed"); 420 return StringAttr::get(ctx, name); 421 }; 422 423 auto getValues = [&](tblgen::Operator::const_value_range namedCons) { 424 SmallVector<Value> operands; 425 SmallVector<Attribute> names; 426 SmallVector<irdl::VariadicityAttr> variadicity; 427 428 for (const NamedTypeConstraint &namedCons : namedCons) { 429 auto operand = createTypeConstraint(consBuilder, namedCons.constraint); 430 operands.push_back(operand); 431 432 names.push_back(normalizeName(namedCons.name)); 433 434 irdl::VariadicityAttr var; 435 if (namedCons.isOptional()) 436 var = consBuilder.getAttr<irdl::VariadicityAttr>( 437 irdl::Variadicity::optional); 438 else if (namedCons.isVariadic()) 439 var = consBuilder.getAttr<irdl::VariadicityAttr>( 440 irdl::Variadicity::variadic); 441 else 442 var = consBuilder.getAttr<irdl::VariadicityAttr>( 443 irdl::Variadicity::single); 444 445 variadicity.push_back(var); 446 } 447 return std::make_tuple(operands, names, variadicity); 448 }; 449 450 auto [operands, operandNames, operandVariadicity] = 451 getValues(tblgenOp.getOperands()); 452 auto [results, resultNames, resultVariadicity] = 453 getValues(tblgenOp.getResults()); 454 455 SmallVector<Value> attributes; 456 SmallVector<Attribute> attrNames; 457 for (auto namedAttr : tblgenOp.getAttributes()) { 458 if (namedAttr.attr.isOptional()) 459 continue; 460 attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr)); 461 attrNames.push_back(StringAttr::get(ctx, namedAttr.name)); 462 } 463 464 SmallVector<Value> regions; 465 SmallVector<Attribute> regionNames; 466 for (auto namedRegion : tblgenOp.getRegions()) { 467 regions.push_back( 468 createRegionConstraint(consBuilder, namedRegion.constraint)); 469 regionNames.push_back(normalizeName(namedRegion.name)); 470 } 471 472 // Create the operands and results operations. 473 if (!operands.empty()) 474 consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands, 475 ArrayAttr::get(ctx, operandNames), 476 operandVariadicity); 477 if (!results.empty()) 478 consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results, 479 ArrayAttr::get(ctx, resultNames), 480 resultVariadicity); 481 if (!attributes.empty()) 482 consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes, 483 ArrayAttr::get(ctx, attrNames)); 484 if (!regions.empty()) 485 consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions, 486 ArrayAttr::get(ctx, regionNames)); 487 488 return op; 489 } 490 491 irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) { 492 MLIRContext *ctx = builder.getContext(); 493 StringRef typeName = getTypeName(tblgenType); 494 std::string combined = ("!" + typeName).str(); 495 496 irdl::TypeOp op = builder.create<irdl::TypeOp>( 497 UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); 498 499 op.getBody().emplaceBlock(); 500 501 return op; 502 } 503 504 irdl::AttributeOp createIRDLAttr(OpBuilder &builder, 505 tblgen::AttrDef &tblgenAttr) { 506 MLIRContext *ctx = builder.getContext(); 507 StringRef attrName = getAttrName(tblgenAttr); 508 std::string combined = ("#" + attrName).str(); 509 510 irdl::AttributeOp op = builder.create<irdl::AttributeOp>( 511 UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); 512 513 op.getBody().emplaceBlock(); 514 515 return op; 516 } 517 518 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { 519 MLIRContext *ctx = builder.getContext(); 520 return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx), 521 StringAttr::get(ctx, selectedDialect)); 522 } 523 524 static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) { 525 // Initialize. 526 MLIRContext ctx; 527 ctx.getOrLoadDialect<irdl::IRDLDialect>(); 528 OpBuilder builder(&ctx); 529 530 // Create a module op and set it as the insertion point. 531 OwningOpRef<ModuleOp> module = 532 builder.create<ModuleOp>(UnknownLoc::get(&ctx)); 533 builder = builder.atBlockBegin(module->getBody()); 534 // Create the dialect and insert it. 535 irdl::DialectOp dialect = createIRDLDialect(builder); 536 // Set insertion point to start of DialectOp. 537 builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock()); 538 539 for (const Record *type : 540 records.getAllDerivedDefinitionsIfDefined("TypeDef")) { 541 tblgen::TypeDef tblgenType(type); 542 if (tblgenType.getDialect().getName() != selectedDialect) 543 continue; 544 createIRDLType(builder, tblgenType); 545 } 546 547 for (const Record *attr : 548 records.getAllDerivedDefinitionsIfDefined("AttrDef")) { 549 tblgen::AttrDef tblgenAttr(attr); 550 if (tblgenAttr.getDialect().getName() != selectedDialect) 551 continue; 552 createIRDLAttr(builder, tblgenAttr); 553 } 554 555 for (const Record *def : records.getAllDerivedDefinitionsIfDefined("Op")) { 556 tblgen::Operator tblgenOp(def); 557 if (tblgenOp.getDialectName() != selectedDialect) 558 continue; 559 560 createIRDLOperation(builder, tblgenOp); 561 } 562 563 // Print the module. 564 module->print(os); 565 566 return false; 567 } 568 569 static mlir::GenRegistration 570 genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions", 571 [](const RecordKeeper &records, raw_ostream &os) { 572 return emitDialectIRDLDefs(records, os); 573 }); 574