xref: /llvm-project/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (revision 69d3ba3db922fca8cfc47b5f115b6bea6a737aab)
1 //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate IRDL
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/IRDL/IR/IRDL.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/TableGen/AttrOrTypeDef.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/GenNameParser.h"
24 #include "mlir/TableGen/Interfaces.h"
25 #include "mlir/TableGen/Operator.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/InitLLVM.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "llvm/TableGen/Main.h"
31 #include "llvm/TableGen/Record.h"
32 #include "llvm/TableGen/TableGenBackend.h"
33 
34 using namespace llvm;
35 using namespace mlir;
36 using tblgen::NamedTypeConstraint;
37 
38 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
39 llvm::cl::opt<std::string>
40     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
41                     llvm::cl::cat(dialectGenCat), llvm::cl::Required);
42 
43 Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
44   MLIRContext *ctx = builder.getContext();
45 
46   if (pred.isCombined()) {
47     auto combiner = pred.getDef().getValueAsDef("kind")->getName();
48     if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
49       std::vector<Value> constraints;
50       for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
51         constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
52       }
53       if (combiner == "PredCombinerAnd") {
54         auto op =
55             builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
56         return op.getOutput();
57       }
58       auto op =
59           builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
60       return op.getOutput();
61     }
62   }
63 
64   std::string condition = pred.getCondition();
65   // Build a CPredOp to match the C constraint built.
66   irdl::CPredOp op = builder.create<irdl::CPredOp>(
67       UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
68   return op;
69 }
70 
71 Value typeToConstraint(OpBuilder &builder, Type type) {
72   MLIRContext *ctx = builder.getContext();
73   auto op =
74       builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
75   return op.getOutput();
76 }
77 
78 Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
79   MLIRContext *ctx = builder.getContext();
80   auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
81                                          StringAttr::get(ctx, baseClass));
82   return op.getOutput();
83 }
84 
85 std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
86   if (predRec.isSubClassOf("I")) {
87     auto width = predRec.getValueAsInt("bitwidth");
88     return IntegerType::get(ctx, width, IntegerType::Signless);
89   }
90 
91   if (predRec.isSubClassOf("SI")) {
92     auto width = predRec.getValueAsInt("bitwidth");
93     return IntegerType::get(ctx, width, IntegerType::Signed);
94   }
95 
96   if (predRec.isSubClassOf("UI")) {
97     auto width = predRec.getValueAsInt("bitwidth");
98     return IntegerType::get(ctx, width, IntegerType::Unsigned);
99   }
100 
101   // Index type
102   if (predRec.getName() == "Index") {
103     return IndexType::get(ctx);
104   }
105 
106   // Float types
107   if (predRec.isSubClassOf("F")) {
108     auto width = predRec.getValueAsInt("bitwidth");
109     switch (width) {
110     case 16:
111       return Float16Type::get(ctx);
112     case 32:
113       return Float32Type::get(ctx);
114     case 64:
115       return Float64Type::get(ctx);
116     case 80:
117       return Float80Type::get(ctx);
118     case 128:
119       return Float128Type::get(ctx);
120     }
121   }
122 
123   if (predRec.getName() == "NoneType") {
124     return NoneType::get(ctx);
125   }
126 
127   if (predRec.getName() == "BF16") {
128     return BFloat16Type::get(ctx);
129   }
130 
131   if (predRec.getName() == "TF32") {
132     return FloatTF32Type::get(ctx);
133   }
134 
135   if (predRec.getName() == "F8E4M3FN") {
136     return Float8E4M3FNType::get(ctx);
137   }
138 
139   if (predRec.getName() == "F8E5M2") {
140     return Float8E5M2Type::get(ctx);
141   }
142 
143   if (predRec.getName() == "F8E4M3") {
144     return Float8E4M3Type::get(ctx);
145   }
146 
147   if (predRec.getName() == "F8E4M3FNUZ") {
148     return Float8E4M3FNUZType::get(ctx);
149   }
150 
151   if (predRec.getName() == "F8E4M3B11FNUZ") {
152     return Float8E4M3B11FNUZType::get(ctx);
153   }
154 
155   if (predRec.getName() == "F8E5M2FNUZ") {
156     return Float8E5M2FNUZType::get(ctx);
157   }
158 
159   if (predRec.getName() == "F8E3M4") {
160     return Float8E3M4Type::get(ctx);
161   }
162 
163   if (predRec.isSubClassOf("Complex")) {
164     const Record *elementRec = predRec.getValueAsDef("elementType");
165     auto elementType = recordToType(ctx, *elementRec);
166     if (elementType.has_value()) {
167       return ComplexType::get(elementType.value());
168     }
169   }
170 
171   return std::nullopt;
172 }
173 
174 Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
175   MLIRContext *ctx = builder.getContext();
176   const Record &predRec = constraint.getDef();
177 
178   if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
179     return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
180 
181   if (predRec.getName() == "AnyType") {
182     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
183     return op.getOutput();
184   }
185 
186   if (predRec.isSubClassOf("TypeDef")) {
187     auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
188     if (dialect == selectedDialect) {
189       std::string combined = ("!" + predRec.getValueAsString("mnemonic")).str();
190       SmallVector<FlatSymbolRefAttr> nested = {
191           SymbolRefAttr::get(ctx, combined)};
192       auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
193       auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
194       return op.getOutput();
195     }
196     std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
197     auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
198                                            StringAttr::get(ctx, typeName));
199     return op.getOutput();
200   }
201 
202   if (predRec.isSubClassOf("AnyTypeOf")) {
203     std::vector<Value> constraints;
204     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
205       constraints.push_back(
206           createTypeConstraint(builder, tblgen::Constraint(child)));
207     }
208     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
209     return op.getOutput();
210   }
211 
212   if (predRec.isSubClassOf("AllOfType")) {
213     std::vector<Value> constraints;
214     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
215       constraints.push_back(
216           createTypeConstraint(builder, tblgen::Constraint(child)));
217     }
218     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
219     return op.getOutput();
220   }
221 
222   // Integer types
223   if (predRec.getName() == "AnyInteger") {
224     auto op = builder.create<irdl::BaseOp>(
225         UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
226     return op.getOutput();
227   }
228 
229   if (predRec.isSubClassOf("AnyI")) {
230     auto width = predRec.getValueAsInt("bitwidth");
231     std::vector<Value> types = {
232         typeToConstraint(builder,
233                          IntegerType::get(ctx, width, IntegerType::Signless)),
234         typeToConstraint(builder,
235                          IntegerType::get(ctx, width, IntegerType::Signed)),
236         typeToConstraint(builder,
237                          IntegerType::get(ctx, width, IntegerType::Unsigned))};
238     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
239     return op.getOutput();
240   }
241 
242   auto type = recordToType(ctx, predRec);
243 
244   if (type.has_value()) {
245     return typeToConstraint(builder, type.value());
246   }
247 
248   // Confined type
249   if (predRec.isSubClassOf("ConfinedType")) {
250     std::vector<Value> constraints;
251     constraints.push_back(createTypeConstraint(
252         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
253     for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) {
254       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
255     }
256     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
257     return op.getOutput();
258   }
259 
260   return createPredicate(builder, constraint.getPredicate());
261 }
262 
263 Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
264   MLIRContext *ctx = builder.getContext();
265   const Record &predRec = constraint.getDef();
266 
267   if (predRec.isSubClassOf("DefaultValuedAttr") ||
268       predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
269       predRec.isSubClassOf("OptionalAttr")) {
270     return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
271   }
272 
273   if (predRec.isSubClassOf("ConfinedAttr")) {
274     std::vector<Value> constraints;
275     constraints.push_back(createAttrConstraint(
276         builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
277     for (const Record *child :
278          predRec.getValueAsListOfDefs("attrConstraints")) {
279       constraints.push_back(createPredicate(
280           builder, tblgen::Pred(child->getValueAsDef("predicate"))));
281     }
282     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
283     return op.getOutput();
284   }
285 
286   if (predRec.isSubClassOf("AnyAttrOf")) {
287     std::vector<Value> constraints;
288     for (const Record *child :
289          predRec.getValueAsListOfDefs("allowedAttributes")) {
290       constraints.push_back(
291           createAttrConstraint(builder, tblgen::Constraint(child)));
292     }
293     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
294     return op.getOutput();
295   }
296 
297   if (predRec.getName() == "AnyAttr") {
298     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
299     return op.getOutput();
300   }
301 
302   if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
303       predRec.isSubClassOf("SignlessIntegerAttrBase") ||
304       predRec.isSubClassOf("SignedIntegerAttrBase") ||
305       predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
306       predRec.isSubClassOf("BoolAttr")) {
307     return baseToConstraint(builder, "!builtin.integer");
308   }
309 
310   if (predRec.isSubClassOf("FloatAttrBase")) {
311     return baseToConstraint(builder, "!builtin.float");
312   }
313 
314   if (predRec.isSubClassOf("StringBasedAttr")) {
315     return baseToConstraint(builder, "!builtin.string");
316   }
317 
318   if (predRec.getName() == "UnitAttr") {
319     auto op =
320         builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
321     return op.getOutput();
322   }
323 
324   if (predRec.isSubClassOf("AttrDef")) {
325     auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
326     if (dialect == selectedDialect) {
327       std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
328       SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
329 
330       };
331       auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
332       auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
333       return op.getOutput();
334     }
335     std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
336     auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
337                                            StringAttr::get(ctx, typeName));
338     return op.getOutput();
339   }
340 
341   return createPredicate(builder, constraint.getPredicate());
342 }
343 
344 Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) {
345   MLIRContext *ctx = builder.getContext();
346   const Record &predRec = constraint.getDef();
347 
348   if (predRec.getName() == "AnyRegion") {
349     ValueRange entryBlockArgs = {};
350     auto op =
351         builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs);
352     return op.getResult();
353   }
354 
355   if (predRec.isSubClassOf("SizedRegion")) {
356     ValueRange entryBlockArgs = {};
357     auto ty = IntegerType::get(ctx, 32);
358     auto op = builder.create<irdl::RegionOp>(
359         UnknownLoc::get(ctx), entryBlockArgs,
360         IntegerAttr::get(ty, predRec.getValueAsInt("blocks")));
361     return op.getResult();
362   }
363 
364   return createPredicate(builder, constraint.getPredicate());
365 }
366 
367 /// Returns the name of the operation without the dialect prefix.
368 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
369   StringRef opName = tblgenOp.getDef().getValueAsString("opName");
370   return opName;
371 }
372 
373 /// Returns the name of the type without the dialect prefix.
374 static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
375   StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
376   return opName;
377 }
378 
379 /// Returns the name of the attr without the dialect prefix.
380 static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
381   StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
382   return opName;
383 }
384 
385 /// Extract an operation to IRDL.
386 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
387                                       tblgen::Operator &tblgenOp) {
388   MLIRContext *ctx = builder.getContext();
389   StringRef opName = getOperatorName(tblgenOp);
390 
391   irdl::OperationOp op = builder.create<irdl::OperationOp>(
392       UnknownLoc::get(ctx), StringAttr::get(ctx, opName));
393 
394   // Add the block in the region.
395   Block &opBlock = op.getBody().emplaceBlock();
396   OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);
397 
398   SmallDenseSet<StringRef> usedNames;
399   for (auto &namedCons : tblgenOp.getOperands())
400     usedNames.insert(namedCons.name);
401   for (auto &namedCons : tblgenOp.getResults())
402     usedNames.insert(namedCons.name);
403   for (auto &namedReg : tblgenOp.getRegions())
404     usedNames.insert(namedReg.name);
405 
406   size_t generateCounter = 0;
407   auto generateName = [&](StringRef prefix) -> StringAttr {
408     SmallString<16> candidate;
409     do {
410       candidate.clear();
411       raw_svector_ostream candidateStream(candidate);
412       candidateStream << prefix << generateCounter;
413       generateCounter++;
414     } while (usedNames.contains(candidate));
415     return StringAttr::get(ctx, candidate);
416   };
417   auto normalizeName = [&](StringRef name) -> StringAttr {
418     if (name == "")
419       return generateName("unnamed");
420     return StringAttr::get(ctx, name);
421   };
422 
423   auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
424     SmallVector<Value> operands;
425     SmallVector<Attribute> names;
426     SmallVector<irdl::VariadicityAttr> variadicity;
427 
428     for (const NamedTypeConstraint &namedCons : namedCons) {
429       auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
430       operands.push_back(operand);
431 
432       names.push_back(normalizeName(namedCons.name));
433 
434       irdl::VariadicityAttr var;
435       if (namedCons.isOptional())
436         var = consBuilder.getAttr<irdl::VariadicityAttr>(
437             irdl::Variadicity::optional);
438       else if (namedCons.isVariadic())
439         var = consBuilder.getAttr<irdl::VariadicityAttr>(
440             irdl::Variadicity::variadic);
441       else
442         var = consBuilder.getAttr<irdl::VariadicityAttr>(
443             irdl::Variadicity::single);
444 
445       variadicity.push_back(var);
446     }
447     return std::make_tuple(operands, names, variadicity);
448   };
449 
450   auto [operands, operandNames, operandVariadicity] =
451       getValues(tblgenOp.getOperands());
452   auto [results, resultNames, resultVariadicity] =
453       getValues(tblgenOp.getResults());
454 
455   SmallVector<Value> attributes;
456   SmallVector<Attribute> attrNames;
457   for (auto namedAttr : tblgenOp.getAttributes()) {
458     if (namedAttr.attr.isOptional())
459       continue;
460     attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
461     attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
462   }
463 
464   SmallVector<Value> regions;
465   SmallVector<Attribute> regionNames;
466   for (auto namedRegion : tblgenOp.getRegions()) {
467     regions.push_back(
468         createRegionConstraint(consBuilder, namedRegion.constraint));
469     regionNames.push_back(normalizeName(namedRegion.name));
470   }
471 
472   // Create the operands and results operations.
473   if (!operands.empty())
474     consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
475                                          ArrayAttr::get(ctx, operandNames),
476                                          operandVariadicity);
477   if (!results.empty())
478     consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
479                                         ArrayAttr::get(ctx, resultNames),
480                                         resultVariadicity);
481   if (!attributes.empty())
482     consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
483                                            ArrayAttr::get(ctx, attrNames));
484   if (!regions.empty())
485     consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions,
486                                         ArrayAttr::get(ctx, regionNames));
487 
488   return op;
489 }
490 
491 irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
492   MLIRContext *ctx = builder.getContext();
493   StringRef typeName = getTypeName(tblgenType);
494   std::string combined = ("!" + typeName).str();
495 
496   irdl::TypeOp op = builder.create<irdl::TypeOp>(
497       UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
498 
499   op.getBody().emplaceBlock();
500 
501   return op;
502 }
503 
504 irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
505                                  tblgen::AttrDef &tblgenAttr) {
506   MLIRContext *ctx = builder.getContext();
507   StringRef attrName = getAttrName(tblgenAttr);
508   std::string combined = ("#" + attrName).str();
509 
510   irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
511       UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
512 
513   op.getBody().emplaceBlock();
514 
515   return op;
516 }
517 
518 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
519   MLIRContext *ctx = builder.getContext();
520   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
521                                          StringAttr::get(ctx, selectedDialect));
522 }
523 
524 static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) {
525   // Initialize.
526   MLIRContext ctx;
527   ctx.getOrLoadDialect<irdl::IRDLDialect>();
528   OpBuilder builder(&ctx);
529 
530   // Create a module op and set it as the insertion point.
531   OwningOpRef<ModuleOp> module =
532       builder.create<ModuleOp>(UnknownLoc::get(&ctx));
533   builder = builder.atBlockBegin(module->getBody());
534   // Create the dialect and insert it.
535   irdl::DialectOp dialect = createIRDLDialect(builder);
536   // Set insertion point to start of DialectOp.
537   builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
538 
539   for (const Record *type :
540        records.getAllDerivedDefinitionsIfDefined("TypeDef")) {
541     tblgen::TypeDef tblgenType(type);
542     if (tblgenType.getDialect().getName() != selectedDialect)
543       continue;
544     createIRDLType(builder, tblgenType);
545   }
546 
547   for (const Record *attr :
548        records.getAllDerivedDefinitionsIfDefined("AttrDef")) {
549     tblgen::AttrDef tblgenAttr(attr);
550     if (tblgenAttr.getDialect().getName() != selectedDialect)
551       continue;
552     createIRDLAttr(builder, tblgenAttr);
553   }
554 
555   for (const Record *def : records.getAllDerivedDefinitionsIfDefined("Op")) {
556     tblgen::Operator tblgenOp(def);
557     if (tblgenOp.getDialectName() != selectedDialect)
558       continue;
559 
560     createIRDLOperation(builder, tblgenOp);
561   }
562 
563   // Print the module.
564   module->print(os);
565 
566   return false;
567 }
568 
569 static mlir::GenRegistration
570     genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
571               [](const RecordKeeper &records, raw_ostream &os) {
572                 return emitDialectIRDLDefs(records, os);
573               });
574