xref: /llvm-project/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
1 //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10 // generate the corresponding Python binding classes.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "OpGenHelpers.h"
14 
15 #include "mlir/TableGen/AttrOrTypeDef.h"
16 #include "mlir/TableGen/Attribute.h"
17 #include "mlir/TableGen/Dialect.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 using llvm::formatv;
25 using llvm::Record;
26 using llvm::RecordKeeper;
27 
/// Preamble written once at the top of every generated Python module: an
/// "autogenerated" banner, the `enum` base classes used by the emitted enum
/// classes, and the `_ods_ir` / `register_attribute_builder` handles used by
/// the emitted attribute builders.
static constexpr char fileHeader[] = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
38 
39 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
40 static std::string makePythonEnumCaseName(StringRef name) {
41   if (isPythonReserved(name.str()))
42     return (name + "_").str();
43   return name.str();
44 }
45 
46 /// Emits the Python class for the given enum.
47 static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
48   os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
49                 enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
50   if (!enumAttr.getSummary().empty())
51     os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
52   os << "\n";
53 
54   for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
55     os << formatv("    {0} = {1}\n",
56                   makePythonEnumCaseName(enumCase.getSymbol()),
57                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
58                                            : "auto()");
59   }
60 
61   os << "\n";
62 
63   if (enumAttr.isBitEnum()) {
64     os << formatv("    def __iter__(self):\n"
65                   "        return iter([case for case in type(self) if "
66                   "(self & case) is case])\n");
67     os << formatv("    def __len__(self):\n"
68                   "        return bin(self).count(\"1\")\n");
69     os << "\n";
70   }
71 
72   os << formatv("    def __str__(self):\n");
73   if (enumAttr.isBitEnum())
74     os << formatv("        if len(self) > 1:\n"
75                   "            return \"{0}\".join(map(str, self))\n",
76                   enumAttr.getDef().getValueAsString("separator"));
77   for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
78     os << formatv("        if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
79                   makePythonEnumCaseName(enumCase.getSymbol()));
80     os << formatv("            return \"{0}\"\n", enumCase.getStr());
81   }
82   os << formatv("        raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
83                 enumAttr.getEnumClassName());
84   os << "\n";
85 }
86 
87 /// Attempts to extract the bitwidth B from string "uintB_t" describing the
88 /// type. This bitwidth information is not readily available in ODS. Returns
89 /// `false` on success, `true` on failure.
90 static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
91   if (!uintType.consume_front("uint"))
92     return true;
93   if (!uintType.consume_back("_t"))
94     return true;
95   return uintType.getAsInteger(/*Radix=*/10, bitwidth);
96 }
97 
98 /// Emits an attribute builder for the given enum attribute to support automatic
99 /// conversion between enum values and attributes in Python. Returns
100 /// `false` on success, `true` on failure.
101 static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
102   int64_t bitwidth;
103   if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
104     llvm::errs() << "failed to identify bitwidth of "
105                  << enumAttr.getUnderlyingType();
106     return true;
107   }
108 
109   os << formatv("@register_attribute_builder(\"{0}\")\n",
110                 enumAttr.getAttrDefName());
111   os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
112   os << formatv("    return "
113                 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
114                 "context=context), int(x))\n\n",
115                 bitwidth);
116   return false;
117 }
118 
119 /// Emits an attribute builder for the given dialect enum attribute to support
120 /// automatic conversion between enum values and attributes in Python. Returns
121 /// `false` on success, `true` on failure.
122 static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
123                                             StringRef formatString,
124                                             raw_ostream &os) {
125   os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
126   os << formatv("def _{0}(x, context):\n", attrDefName.lower());
127   os << formatv("    return "
128                 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
129                 formatString);
130   return false;
131 }
132 
133 /// Emits Python bindings for all enums in the record keeper. Returns
134 /// `false` on success, `true` on failure.
135 static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
136   os << fileHeader;
137   for (const Record *it :
138        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
139     EnumAttr enumAttr(*it);
140     emitEnumClass(enumAttr, os);
141     emitAttributeBuilder(enumAttr, os);
142   }
143   for (const Record *it :
144        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
145     AttrOrTypeDef attr(&*it);
146     if (!attr.getMnemonic()) {
147       llvm::errs() << "enum case " << attr
148                    << " needs mnemonic for python enum bindings generation";
149       return true;
150     }
151     StringRef mnemonic = attr.getMnemonic().value();
152     std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
153     StringRef dialect = attr.getDialect().getName();
154     if (assemblyFormat == "`<` $value `>`") {
155       emitDialectEnumAttributeBuilder(
156           attr.getName(),
157           formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
158     } else if (assemblyFormat == "$value") {
159       emitDialectEnumAttributeBuilder(
160           attr.getName(),
161           formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
162     } else {
163       llvm::errs()
164           << "unsupported assembly format for python enum bindings generation";
165       return true;
166     }
167   }
168 
169   return false;
170 }
171 
172 // Registers the enum utility generator to mlir-tblgen.
173 static mlir::GenRegistration
174     genPythonEnumBindings("gen-python-enum-bindings",
175                           "Generate Python bindings for enum attributes",
176                           &emitPythonEnums);
177