xref: /llvm-project/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
11f8618f8SAlex Zinenko //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
21f8618f8SAlex Zinenko //
31f8618f8SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41f8618f8SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
51f8618f8SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61f8618f8SAlex Zinenko //
71f8618f8SAlex Zinenko //===----------------------------------------------------------------------===//
81f8618f8SAlex Zinenko //
91f8618f8SAlex Zinenko // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
101f8618f8SAlex Zinenko // generate the corresponding Python binding classes.
111f8618f8SAlex Zinenko //
121f8618f8SAlex Zinenko //===----------------------------------------------------------------------===//
1392233062Smax #include "OpGenHelpers.h"
141f8618f8SAlex Zinenko 
1592233062Smax #include "mlir/TableGen/AttrOrTypeDef.h"
161f8618f8SAlex Zinenko #include "mlir/TableGen/Attribute.h"
1792233062Smax #include "mlir/TableGen/Dialect.h"
181f8618f8SAlex Zinenko #include "mlir/TableGen/GenInfo.h"
191f8618f8SAlex Zinenko #include "llvm/Support/FormatVariadic.h"
201f8618f8SAlex Zinenko #include "llvm/TableGen/Record.h"
211f8618f8SAlex Zinenko 
221f8618f8SAlex Zinenko using namespace mlir;
231f8618f8SAlex Zinenko using namespace mlir::tblgen;
24bccd37f6SRahul Joshi using llvm::formatv;
25bccd37f6SRahul Joshi using llvm::Record;
26bccd37f6SRahul Joshi using llvm::RecordKeeper;
271f8618f8SAlex Zinenko 
/// Python prologue written at the top of every generated bindings file. It
/// pulls in the `enum` base classes and `auto` used by the emitted enum
/// classes, the native extension module, and the attribute-builder
/// registration decorator.
constexpr const char *fileHeader =
    "\n"
    "# Autogenerated by mlir-tblgen; don't manually edit.\n"
    "\n"
    "from enum import IntEnum, auto, IntFlag\n"
    "from ._ods_common import _cext as _ods_cext\n"
    "from ..ir import register_attribute_builder\n"
    "_ods_ir = _ods_cext.ir\n"
    "\n";
381f8618f8SAlex Zinenko 
391f8618f8SAlex Zinenko /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
401f8618f8SAlex Zinenko static std::string makePythonEnumCaseName(StringRef name) {
4192233062Smax   if (isPythonReserved(name.str()))
4292233062Smax     return (name + "_").str();
4392233062Smax   return name.str();
441f8618f8SAlex Zinenko }
451f8618f8SAlex Zinenko 
461f8618f8SAlex Zinenko /// Emits the Python class for the given enum.
4792233062Smax static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
48bccd37f6SRahul Joshi   os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
4992233062Smax                 enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
5092233062Smax   if (!enumAttr.getSummary().empty())
51bccd37f6SRahul Joshi     os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
521f8618f8SAlex Zinenko   os << "\n";
531f8618f8SAlex Zinenko 
5492233062Smax   for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
55bccd37f6SRahul Joshi     os << formatv("    {0} = {1}\n",
56bccd37f6SRahul Joshi                   makePythonEnumCaseName(enumCase.getSymbol()),
5792233062Smax                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
5892233062Smax                                            : "auto()");
591f8618f8SAlex Zinenko   }
601f8618f8SAlex Zinenko 
611f8618f8SAlex Zinenko   os << "\n";
6292233062Smax 
6392233062Smax   if (enumAttr.isBitEnum()) {
64bccd37f6SRahul Joshi     os << formatv("    def __iter__(self):\n"
6592233062Smax                   "        return iter([case for case in type(self) if "
6692233062Smax                   "(self & case) is case])\n");
67bccd37f6SRahul Joshi     os << formatv("    def __len__(self):\n"
6892233062Smax                   "        return bin(self).count(\"1\")\n");
6992233062Smax     os << "\n";
7092233062Smax   }
7192233062Smax 
72bccd37f6SRahul Joshi   os << formatv("    def __str__(self):\n");
7392233062Smax   if (enumAttr.isBitEnum())
74bccd37f6SRahul Joshi     os << formatv("        if len(self) > 1:\n"
7592233062Smax                   "            return \"{0}\".join(map(str, self))\n",
7692233062Smax                   enumAttr.getDef().getValueAsString("separator"));
7792233062Smax   for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
78bccd37f6SRahul Joshi     os << formatv("        if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
791f8618f8SAlex Zinenko                   makePythonEnumCaseName(enumCase.getSymbol()));
80bccd37f6SRahul Joshi     os << formatv("            return \"{0}\"\n", enumCase.getStr());
811f8618f8SAlex Zinenko   }
82bccd37f6SRahul Joshi   os << formatv("        raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
8392233062Smax                 enumAttr.getEnumClassName());
8492233062Smax   os << "\n";
851f8618f8SAlex Zinenko }
861f8618f8SAlex Zinenko 
871f8618f8SAlex Zinenko /// Attempts to extract the bitwidth B from string "uintB_t" describing the
881f8618f8SAlex Zinenko /// type. This bitwidth information is not readily available in ODS. Returns
891f8618f8SAlex Zinenko /// `false` on success, `true` on failure.
901f8618f8SAlex Zinenko static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
911f8618f8SAlex Zinenko   if (!uintType.consume_front("uint"))
921f8618f8SAlex Zinenko     return true;
931f8618f8SAlex Zinenko   if (!uintType.consume_back("_t"))
941f8618f8SAlex Zinenko     return true;
951f8618f8SAlex Zinenko   return uintType.getAsInteger(/*Radix=*/10, bitwidth);
961f8618f8SAlex Zinenko }
971f8618f8SAlex Zinenko 
981f8618f8SAlex Zinenko /// Emits an attribute builder for the given enum attribute to support automatic
991f8618f8SAlex Zinenko /// conversion between enum values and attributes in Python. Returns
1001f8618f8SAlex Zinenko /// `false` on success, `true` on failure.
1011f8618f8SAlex Zinenko static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
1021f8618f8SAlex Zinenko   int64_t bitwidth;
1031f8618f8SAlex Zinenko   if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
1041f8618f8SAlex Zinenko     llvm::errs() << "failed to identify bitwidth of "
1051f8618f8SAlex Zinenko                  << enumAttr.getUnderlyingType();
1061f8618f8SAlex Zinenko     return true;
1071f8618f8SAlex Zinenko   }
1081f8618f8SAlex Zinenko 
109bccd37f6SRahul Joshi   os << formatv("@register_attribute_builder(\"{0}\")\n",
1101f8618f8SAlex Zinenko                 enumAttr.getAttrDefName());
111bccd37f6SRahul Joshi   os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
112bccd37f6SRahul Joshi   os << formatv("    return "
1131f8618f8SAlex Zinenko                 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
11492233062Smax                 "context=context), int(x))\n\n",
1151f8618f8SAlex Zinenko                 bitwidth);
1161f8618f8SAlex Zinenko   return false;
1171f8618f8SAlex Zinenko }
1181f8618f8SAlex Zinenko 
11992233062Smax /// Emits an attribute builder for the given dialect enum attribute to support
12092233062Smax /// automatic conversion between enum values and attributes in Python. Returns
12192233062Smax /// `false` on success, `true` on failure.
12292233062Smax static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
12392233062Smax                                             StringRef formatString,
12492233062Smax                                             raw_ostream &os) {
125bccd37f6SRahul Joshi   os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
126bccd37f6SRahul Joshi   os << formatv("def _{0}(x, context):\n", attrDefName.lower());
127bccd37f6SRahul Joshi   os << formatv("    return "
12892233062Smax                 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
12992233062Smax                 formatString);
13092233062Smax   return false;
13192233062Smax }
13292233062Smax 
1331f8618f8SAlex Zinenko /// Emits Python bindings for all enums in the record keeper. Returns
1341f8618f8SAlex Zinenko /// `false` on success, `true` on failure.
135*e8137503SRahul Joshi static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
1361f8618f8SAlex Zinenko   os << fileHeader;
137bccd37f6SRahul Joshi   for (const Record *it :
138*e8137503SRahul Joshi        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
13992233062Smax     EnumAttr enumAttr(*it);
14092233062Smax     emitEnumClass(enumAttr, os);
1411f8618f8SAlex Zinenko     emitAttributeBuilder(enumAttr, os);
1421f8618f8SAlex Zinenko   }
143bccd37f6SRahul Joshi   for (const Record *it :
144*e8137503SRahul Joshi        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
14592233062Smax     AttrOrTypeDef attr(&*it);
14692233062Smax     if (!attr.getMnemonic()) {
14792233062Smax       llvm::errs() << "enum case " << attr
14892233062Smax                    << " needs mnemonic for python enum bindings generation";
14992233062Smax       return true;
15092233062Smax     }
15192233062Smax     StringRef mnemonic = attr.getMnemonic().value();
15292233062Smax     std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
15392233062Smax     StringRef dialect = attr.getDialect().getName();
15492233062Smax     if (assemblyFormat == "`<` $value `>`") {
15592233062Smax       emitDialectEnumAttributeBuilder(
15692233062Smax           attr.getName(),
157bccd37f6SRahul Joshi           formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
15892233062Smax     } else if (assemblyFormat == "$value") {
15992233062Smax       emitDialectEnumAttributeBuilder(
16092233062Smax           attr.getName(),
161bccd37f6SRahul Joshi           formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
16292233062Smax     } else {
16392233062Smax       llvm::errs()
16492233062Smax           << "unsupported assembly format for python enum bindings generation";
16592233062Smax       return true;
16692233062Smax     }
16792233062Smax   }
16892233062Smax 
1691f8618f8SAlex Zinenko   return false;
1701f8618f8SAlex Zinenko }
1711f8618f8SAlex Zinenko 
1721f8618f8SAlex Zinenko // Registers the enum utility generator to mlir-tblgen.
1731f8618f8SAlex Zinenko static mlir::GenRegistration
1741f8618f8SAlex Zinenko     genPythonEnumBindings("gen-python-enum-bindings",
1751f8618f8SAlex Zinenko                           "Generate Python bindings for enum attributes",
1761f8618f8SAlex Zinenko                           &emitPythonEnums);
177