1 //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to 10 // generate the corresponding Python binding classes. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "OpGenHelpers.h" 14 15 #include "mlir/TableGen/AttrOrTypeDef.h" 16 #include "mlir/TableGen/Attribute.h" 17 #include "mlir/TableGen/Dialect.h" 18 #include "mlir/TableGen/GenInfo.h" 19 #include "llvm/Support/FormatVariadic.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace mlir; 23 using namespace mlir::tblgen; 24 using llvm::formatv; 25 using llvm::Record; 26 using llvm::RecordKeeper; 27 28 /// File header and includes. 29 constexpr const char *fileHeader = R"Py( 30 # Autogenerated by mlir-tblgen; don't manually edit. 31 32 from enum import IntEnum, auto, IntFlag 33 from ._ods_common import _cext as _ods_cext 34 from ..ir import register_attribute_builder 35 _ods_ir = _ods_cext.ir 36 37 )Py"; 38 39 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. 40 static std::string makePythonEnumCaseName(StringRef name) { 41 if (isPythonReserved(name.str())) 42 return (name + "_").str(); 43 return name.str(); 44 } 45 46 /// Emits the Python class for the given enum. 47 static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { 48 os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), 49 enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); 50 if (!enumAttr.getSummary().empty()) 51 os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); 52 os << "\n"; 53 54 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { 55 os << formatv(" {0} = {1}\n", 56 makePythonEnumCaseName(enumCase.getSymbol()), 57 enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) 58 : "auto()"); 59 } 60 61 os << "\n"; 62 63 if (enumAttr.isBitEnum()) { 64 os << formatv(" def __iter__(self):\n" 65 " return iter([case for case in type(self) if " 66 "(self & case) is case])\n"); 67 os << formatv(" def __len__(self):\n" 68 " return bin(self).count(\"1\")\n"); 69 os << "\n"; 70 } 71 72 os << formatv(" def __str__(self):\n"); 73 if (enumAttr.isBitEnum()) 74 os << formatv(" if len(self) > 1:\n" 75 " return \"{0}\".join(map(str, self))\n", 76 enumAttr.getDef().getValueAsString("separator")); 77 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { 78 os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(), 79 makePythonEnumCaseName(enumCase.getSymbol())); 80 os << formatv(" return \"{0}\"\n", enumCase.getStr()); 81 } 82 os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", 83 enumAttr.getEnumClassName()); 84 os << "\n"; 85 } 86 87 /// Attempts to extract the bitwidth B from string "uintB_t" describing the 88 /// type. This bitwidth information is not readily available in ODS. Returns 89 /// `false` on success, `true` on failure. 90 static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) { 91 if (!uintType.consume_front("uint")) 92 return true; 93 if (!uintType.consume_back("_t")) 94 return true; 95 return uintType.getAsInteger(/*Radix=*/10, bitwidth); 96 } 97 98 /// Emits an attribute builder for the given enum attribute to support automatic 99 /// conversion between enum values and attributes in Python. Returns 100 /// `false` on success, `true` on failure. 101 static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { 102 int64_t bitwidth; 103 if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) { 104 llvm::errs() << "failed to identify bitwidth of " 105 << enumAttr.getUnderlyingType(); 106 return true; 107 } 108 109 os << formatv("@register_attribute_builder(\"{0}\")\n", 110 enumAttr.getAttrDefName()); 111 os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower()); 112 os << formatv(" return " 113 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " 114 "context=context), int(x))\n\n", 115 bitwidth); 116 return false; 117 } 118 119 /// Emits an attribute builder for the given dialect enum attribute to support 120 /// automatic conversion between enum values and attributes in Python. Returns 121 /// `false` on success, `true` on failure. 122 static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, 123 StringRef formatString, 124 raw_ostream &os) { 125 os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); 126 os << formatv("def _{0}(x, context):\n", attrDefName.lower()); 127 os << formatv(" return " 128 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", 129 formatString); 130 return false; 131 } 132 133 /// Emits Python bindings for all enums in the record keeper. Returns 134 /// `false` on success, `true` on failure. 135 static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { 136 os << fileHeader; 137 for (const Record *it : 138 records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { 139 EnumAttr enumAttr(*it); 140 emitEnumClass(enumAttr, os); 141 emitAttributeBuilder(enumAttr, os); 142 } 143 for (const Record *it : 144 records.getAllDerivedDefinitionsIfDefined("EnumAttr")) { 145 AttrOrTypeDef attr(&*it); 146 if (!attr.getMnemonic()) { 147 llvm::errs() << "enum case " << attr 148 << " needs mnemonic for python enum bindings generation"; 149 return true; 150 } 151 StringRef mnemonic = attr.getMnemonic().value(); 152 std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat(); 153 StringRef dialect = attr.getDialect().getName(); 154 if (assemblyFormat == "`<` $value `>`") { 155 emitDialectEnumAttributeBuilder( 156 attr.getName(), 157 formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); 158 } else if (assemblyFormat == "$value") { 159 emitDialectEnumAttributeBuilder( 160 attr.getName(), 161 formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); 162 } else { 163 llvm::errs() 164 << "unsupported assembly format for python enum bindings generation"; 165 return true; 166 } 167 } 168 169 return false; 170 } 171 172 // Registers the enum utility generator to mlir-tblgen. 173 static mlir::GenRegistration 174 genPythonEnumBindings("gen-python-enum-bindings", 175 "Generate Python bindings for enum attributes", 176 &emitPythonEnums); 177