//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python // binding classes wrapping a generic operation API. // //===----------------------------------------------------------------------===// #include "OpGenHelpers.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" using namespace mlir; using namespace mlir::tblgen; using llvm::formatv; using llvm::Record; using llvm::RecordKeeper; /// File header and includes. /// {0} is the dialect namespace. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_op_results as _get_op_result_or_op_results, get_op_results_or_values as _get_op_results_or_values, segmented_accessor as _ods_segmented_accessor, ) _ods_ir = _ods_cext.ir import builtins from typing import Sequence as _Sequence, Union as _Union )Py"; /// Template for dialect class: /// {0} is the dialect namespace. constexpr const char *dialectClassTemplate = R"Py( @_ods_cext.register_dialect class _Dialect(_ods_ir.Dialect): DIALECT_NAMESPACE = "{0}" )Py"; constexpr const char *dialectExtensionTemplate = R"Py( from ._{0}_ops_gen import _Dialect )Py"; /// Template for operation class: /// {0} is the Python class name; /// {1} is the operation name. 
constexpr const char *opClassTemplate = R"Py( @_ods_cext.register_operation(_Dialect) class {0}(_ods_ir.OpView): OPERATION_NAME = "{1}" )Py"; /// Template for class level declarations of operand and result /// segment specs. /// {0} is either "OPERAND" or "RESULT" /// {1} is the segment spec /// Each segment spec is either None (default) or an array of integers /// where: /// 1 = single element (expect non sequence operand/result) /// 0 = optional element (expect a value or std::nullopt) /// -1 = operand/result is a sequence corresponding to a variadic constexpr const char *opClassSizedSegmentsTemplate = R"Py( _ODS_{0}_SEGMENTS = {1} )Py"; /// Template for class level declarations of the _ODS_REGIONS spec: /// {0} is the minimum number of regions /// {1} is the Python bool literal for hasNoVariadicRegions constexpr const char *opClassRegionSpecTemplate = R"Py( _ODS_REGIONS = ({0}, {1}) )Py"; /// Template for single-element accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the position in the element list. constexpr const char *opSingleTemplate = R"Py( @builtins.property def {0}(self): return self.operation.{1}s[{2}] )Py"; /// Template for single-element accessor after a variable-length group: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. /// This works for both a single variadic group (non-negative length) and an /// single optional element (zero length if the element is absent). 
constexpr const char *opSingleAfterVariableTemplate = R"Py( @builtins.property def {0}(self): _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 return self.operation.{1}s[{3} + _ods_variadic_group_length - 1] )Py"; /// Template for an optional element accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. /// This works if we have only one variable-length group (and it's the optional /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is /// smaller than the total number of groups. constexpr const char *opOneOptionalTemplate = R"Py( @builtins.property def {0}(self): return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}] )Py"; /// Template for the variadic group accessor in the single variadic group case: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. constexpr const char *opOneVariadicTemplate = R"Py( @builtins.property def {0}(self): _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length] )Py"; /// First part of the template for equally-sized variadic group accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the total number of non-variadic groups; /// {3} is the total number of variadic groups; /// {4} is the number of non-variadic groups preceding the current group; /// {5} is the number of variadic groups preceding the current group. 
constexpr const char *opVariadicEqualPrefixTemplate = R"Py( @builtins.property def {0}(self): start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py"; /// Second part of the template for equally-sized case, accessing a single /// element: /// {0} is either 'operand' or 'result'. constexpr const char *opVariadicEqualSimpleTemplate = R"Py( return self.operation.{0}s[start] )Py"; /// Second part of the template for equally-sized case, accessing a variadic /// group: /// {0} is either 'operand' or 'result'. constexpr const char *opVariadicEqualVariadicTemplate = R"Py( return self.operation.{0}s[start:start + elements_per_group] )Py"; /// Template for an attribute-sized group accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the position of the group in the group list; /// {3} is a return suffix (expected [0] for single-element, empty for /// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional). constexpr const char *opVariadicSegmentTemplate = R"Py( @builtins.property def {0}(self): {1}_range = _ods_segmented_accessor( self.operation.{1}s, self.operation.attributes["{1}SegmentSizes"], {2}) return {1}_range{3} )Py"; /// Template for a suffix when accessing an optional element in the /// attribute-sized case: /// {0} is either 'operand' or 'result'; constexpr const char *opVariadicSegmentOptionalTrailingTemplate = R"Py([0] if len({0}_range) > 0 else None)Py"; /// Template for an operation attribute getter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. constexpr const char *attributeGetterTemplate = R"Py( @builtins.property def {0}(self): return self.operation.attributes["{1}"] )Py"; /// Template for an optional operation attribute getter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. 
constexpr const char *optionalAttributeGetterTemplate = R"Py( @builtins.property def {0}(self): if "{1}" not in self.operation.attributes: return None return self.operation.attributes["{1}"] )Py"; /// Template for a getter of a unit operation attribute, returns True of the /// unit attribute is present, False otherwise (unit attributes have meaning /// by mere presence): /// {0} is the name of the attribute sanitized for Python, /// {1} is the original name of the attribute. constexpr const char *unitAttributeGetterTemplate = R"Py( @builtins.property def {0}(self): return "{1}" in self.operation.attributes )Py"; /// Template for an operation attribute setter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. constexpr const char *attributeSetterTemplate = R"Py( @{0}.setter def {0}(self, value): if value is None: raise ValueError("'None' not allowed as value for mandatory attributes") self.operation.attributes["{1}"] = value )Py"; /// Template for a setter of an optional operation attribute, setting to None /// removes the attribute: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. constexpr const char *optionalAttributeSetterTemplate = R"Py( @{0}.setter def {0}(self, value): if value is not None: self.operation.attributes["{1}"] = value elif "{1}" in self.operation.attributes: del self.operation.attributes["{1}"] )Py"; /// Template for a setter of a unit operation attribute, setting to None or /// False removes the attribute: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. 
constexpr const char *unitAttributeSetterTemplate = R"Py( @{0}.setter def {0}(self, value): if bool(value): self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() elif "{1}" in self.operation.attributes: del self.operation.attributes["{1}"] )Py"; /// Template for a deleter of an optional or a unit operation attribute, removes /// the attribute from the operation: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. constexpr const char *attributeDeleterTemplate = R"Py( @{0}.deleter def {0}(self): del self.operation.attributes["{1}"] )Py"; constexpr const char *regionAccessorTemplate = R"Py( @builtins.property def {0}(self): return self.regions[{1}] )Py"; constexpr const char *valueBuilderTemplate = R"Py( def {0}({2}) -> {4}: return {1}({3}){5} )Py"; constexpr const char *valueBuilderVariadicTemplate = R"Py( def {0}({2}) -> {4}: return _get_op_result_or_op_results({1}({3})) )Py"; static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); static llvm::cl::opt clDialectName("bind-dialect", llvm::cl::desc("The dialect to run the generator for"), llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); static llvm::cl::opt clDialectExtensionName( "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); using AttributeClasses = DenseMap; /// Checks whether `str` would shadow a generated variable or attribute /// part of the OpView API. 
static bool isODSReserved(StringRef str) { static llvm::StringSet<> reserved( {"attributes", "create", "context", "ip", "operands", "print", "get_asm", "loc", "verify", "regions", "results", "self", "operation", "DIALECT_NAMESPACE", "OPERATION_NAME"}); return str.starts_with("_ods_") || str.ends_with("_ods") || reserved.contains(str); } /// Modifies the `name` in a way that it becomes suitable for Python bindings /// (does not change the `name` if it already is suitable) and returns the /// modified version. static std::string sanitizeName(StringRef name) { std::string processedStr = name.str(); std::replace_if( processedStr.begin(), processedStr.end(), [](char c) { return !llvm::isAlnum(c); }, '_'); if (llvm::isDigit(*processedStr.begin())) return "_" + processedStr; if (isPythonReserved(processedStr) || isODSReserved(processedStr)) return processedStr + "_"; return processedStr; } static std::string attrSizedTraitForKind(const char *kind) { return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", StringRef(kind).take_front().upper(), StringRef(kind).drop_front()); } /// Emits accessors to "elements" of an Op definition. Currently, the supported /// elements are operands and results, indicated by `kind`, which must be either /// `operand` or `result` and is used verbatim in the emitted code. static void emitElementAccessors( const Operator &op, raw_ostream &os, const char *kind, unsigned numVariadicGroups, unsigned numElements, llvm::function_ref getElement) { assert(llvm::is_contained(SmallVector{"operand", "result"}, kind) && "unsupported kind"); // Traits indicating how to process variadic elements. std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", StringRef(kind).take_front().upper(), StringRef(kind).drop_front()); std::string attrSizedTrait = attrSizedTraitForKind(kind); // If there is only one variable-length element group, its size can be // inferred from the total number of elements. 
If there are none, the // generation is straightforward. if (numVariadicGroups <= 1) { bool seenVariableLength = false; for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (element.isVariableLength()) seenVariableLength = true; if (element.name.empty()) continue; if (element.isVariableLength()) { os << formatv(element.isOptional() ? opOneOptionalTemplate : opOneVariadicTemplate, sanitizeName(element.name), kind, numElements, i); } else if (seenVariableLength) { os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name), kind, numElements, i); } else { os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i); } } return; } // Handle the operations where variadic groups have the same size. if (op.getTrait(sameSizeTrait)) { // Count the number of simple elements unsigned numSimpleLength = 0; for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (!element.isVariableLength()) { ++numSimpleLength; } } // Generate the accessors int numPrecedingSimple = 0; int numPrecedingVariadic = 0; for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (!element.name.empty()) { os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), kind, numSimpleLength, numVariadicGroups, numPrecedingSimple, numPrecedingVariadic); os << formatv(element.isVariableLength() ? opVariadicEqualVariadicTemplate : opVariadicEqualSimpleTemplate, kind); } if (element.isVariableLength()) ++numPrecedingVariadic; else ++numPrecedingSimple; } return; } // Handle the operations where the size of groups (variadic or not) is // provided as an attribute. For non-variadic elements, make sure to return // an element rather than a singleton container. 
if (op.getTrait(attrSizedTrait)) { for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (element.name.empty()) continue; std::string trailing; if (!element.isVariableLength()) trailing = "[0]"; else if (element.isOptional()) trailing = std::string( formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind, i, trailing); } return; } llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); } /// Free function helpers accessing Operator components. static int getNumOperands(const Operator &op) { return op.getNumOperands(); } static const NamedTypeConstraint &getOperand(const Operator &op, int i) { return op.getOperand(i); } static int getNumResults(const Operator &op) { return op.getNumResults(); } static const NamedTypeConstraint &getResult(const Operator &op, int i) { return op.getResult(i); } /// Emits accessors to Op operands. static void emitOperandAccessors(const Operator &op, raw_ostream &os) { emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(), getNumOperands(op), getOperand); } /// Emits accessors Op results. static void emitResultAccessors(const Operator &op, raw_ostream &os) { emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(), getNumResults(op), getResult); } /// Emits accessors to Op attributes. static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { for (const auto &namedAttr : op.getAttributes()) { // Skip "derived" attributes because they are just C++ functions that we // don't currently expose. if (namedAttr.attr.isDerivedAttr()) continue; if (namedAttr.name.empty()) continue; std::string sanitizedName = sanitizeName(namedAttr.name); // Unit attributes are handled specially. 
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name); os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name); os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); continue; } if (namedAttr.attr.isOptional()) { os << formatv(optionalAttributeGetterTemplate, sanitizedName, namedAttr.name); os << formatv(optionalAttributeSetterTemplate, sanitizedName, namedAttr.name); os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); } else { os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name); os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name); // Non-optional attributes cannot be deleted. } } } /// Template for the default auto-generated builder. /// {0} is a comma-separated list of builder arguments, including the trailing /// `loc` and `ip`; /// {1} is the code populating `operands`, `results` and `attributes`, /// `successors` fields. constexpr const char *initTemplate = R"Py( def __init__(self, {0}): operands = [] results = [] attributes = {{} regions = None {1} super().__init__({2}) )Py"; /// Template for appending a single element to the operand/result list. /// {0} is the field name. constexpr const char *singleOperandAppendTemplate = "operands.append({0})"; constexpr const char *singleResultAppendTemplate = "results.append({0})"; /// Template for appending an optional element to the operand/result list. /// {0} is the field name. constexpr const char *optionalAppendOperandTemplate = "if {0} is not None: operands.append({0})"; constexpr const char *optionalAppendAttrSizedOperandsTemplate = "operands.append({0})"; constexpr const char *optionalAppendResultTemplate = "if {0} is not None: results.append({0})"; /// Template for appending a list of elements to the operand/result list. /// {0} is the field name. 
constexpr const char *multiOperandAppendTemplate = "operands.extend(_get_op_results_or_values({0}))"; constexpr const char *multiOperandAppendPackTemplate = "operands.append(_get_op_results_or_values({0}))"; constexpr const char *multiResultAppendTemplate = "results.extend({0})"; /// Template for attribute builder from raw input in the operation builder. /// {0} is the builder argument name; /// {1} is the attribute builder from raw; /// {2} is the attribute builder from raw. /// Use the value the user passed in if either it is already an Attribute or /// there is no method registered to make it an Attribute. constexpr const char *initAttributeWithBuilderTemplate = R"Py(attributes["{1}"] = ({0} if ( isinstance({0}, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains('{2}')) else _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py"; /// Template for attribute builder from raw input for optional attribute in the /// operation builder. /// {0} is the builder argument name; /// {1} is the attribute builder from raw; /// {2} is the attribute builder from raw. /// Use the value the user passed in if either it is already an Attribute or /// there is no method registered to make it an Attribute. constexpr const char *initOptionalAttributeWithBuilderTemplate = R"Py(if {0} is not None: attributes["{1}"] = ({0} if ( isinstance({0}, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains('{2}')) else _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py"; constexpr const char *initUnitAttributeTemplate = R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( _ods_get_default_loc_context(loc)))Py"; /// Template to initialize the successors list in the builder if there are any /// successors. /// {0} is the value to initialize the successors list to. constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; /// Template to append or extend the list of successors in the builder. 
/// {0} is the list method ('append' or 'extend'); /// {1} is the value to add. constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; /// Returns true if the SameArgumentAndResultTypes trait can be used to infer /// result types of the given operation. static bool hasSameArgumentAndResultTypes(const Operator &op) { return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && op.getNumVariableLengthResults() == 0; } /// Returns true if the FirstAttrDerivedResultType trait can be used to infer /// result types of the given operation. static bool hasFirstAttrDerivedResultTypes(const Operator &op) { return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && op.getNumVariableLengthResults() == 0; } /// Returns true if the InferTypeOpInterface can be used to infer result types /// of the given operation. static bool hasInferTypeInterface(const Operator &op) { return op.getTrait("::mlir::InferTypeOpInterface::Trait") && op.getNumRegions() == 0; } /// Returns true if there is a trait or interface that can be used to infer /// result types of the given operation. static bool canInferType(const Operator &op) { return hasSameArgumentAndResultTypes(op) || hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); } /// Populates `builderArgs` with result names if the builder is expected to /// accept them as arguments. static void populateBuilderArgsResults(const Operator &op, SmallVectorImpl &builderArgs) { if (canInferType(op)) return; for (int i = 0, e = op.getNumResults(); i < e; ++i) { std::string name = op.getResultName(i).str(); if (name.empty()) { if (op.getNumResults() == 1) { // Special case for one result, make the default name be 'result' // to properly match the built-in result accessor. 
name = "result"; } else { name = formatv("_gen_res_{0}", i); } } name = sanitizeName(name); builderArgs.push_back(name); } } /// Populates `builderArgs` with the Python-compatible names of builder function /// arguments using intermixed attributes and operands in the same order as they /// appear in the `arguments` field of the op definition. Additionally, /// `operandNames` is populated with names of operands in their order of /// appearance. static void populateBuilderArgs(const Operator &op, SmallVectorImpl &builderArgs, SmallVectorImpl &operandNames) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) name = formatv("_gen_arg_{0}", i); name = sanitizeName(name); builderArgs.push_back(name); if (!isa(op.getArg(i))) operandNames.push_back(name); } } /// Populates `builderArgs` with the Python-compatible names of builder function /// successor arguments. Additionally, `successorArgNames` is also populated. static void populateBuilderArgsSuccessors(const Operator &op, SmallVectorImpl &builderArgs, SmallVectorImpl &successorArgNames) { for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { NamedSuccessor successor = op.getSuccessor(i); std::string name = std::string(successor.name); if (name.empty()) name = formatv("_gen_successor_{0}", i); name = sanitizeName(name); builderArgs.push_back(name); successorArgNames.push_back(name); } } /// Populates `builderLines` with additional lines that are required in the /// builder to set up operation attributes. `argNames` is expected to contain /// the names of builder arguments that correspond to op arguments, i.e. to the /// operands and attributes in the same order as they appear in the `arguments` /// field. 
static void populateBuilderLinesAttr(const Operator &op, ArrayRef argNames, SmallVectorImpl &builderLines) { builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)"); for (int i = 0, e = op.getNumArgs(); i < e; ++i) { Argument arg = op.getArg(i); auto *attribute = llvm::dyn_cast_if_present(arg); if (!attribute) continue; // Unit attributes are handled specially. if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { builderLines.push_back( formatv(initUnitAttributeTemplate, attribute->name, argNames[i])); continue; } builderLines.push_back(formatv( attribute->attr.isOptional() || attribute->attr.hasDefaultValue() ? initOptionalAttributeWithBuilderTemplate : initAttributeWithBuilderTemplate, argNames[i], attribute->name, attribute->attr.getAttrDefName())); } } /// Populates `builderLines` with additional lines that are required in the /// builder to set up successors. successorArgNames is expected to correspond /// to the Python argument name for each successor on the op. static void populateBuilderLinesSuccessors(const Operator &op, ArrayRef successorArgNames, SmallVectorImpl &builderLines) { if (successorArgNames.empty()) { builderLines.push_back(formatv(initSuccessorsTemplate, "None")); return; } builderLines.push_back(formatv(initSuccessorsTemplate, "[]")); for (int i = 0, e = successorArgNames.size(); i < e; ++i) { auto &argName = successorArgNames[i]; const NamedSuccessor &successor = op.getSuccessor(i); builderLines.push_back(formatv(addSuccessorTemplate, successor.isVariadic() ? "extend" : "append", argName)); } } /// Populates `builderLines` with additional lines that are required in the /// builder to set up op operands. static void populateBuilderLinesOperand(const Operator &op, ArrayRef names, SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; // For each element, find or generate a name. 
for (int i = 0, e = op.getNumOperands(); i < e; ++i) { const NamedTypeConstraint &element = op.getOperand(i); std::string name = names[i]; // Choose the formatting string based on the element kind. StringRef formatString; if (!element.isVariableLength()) { formatString = singleOperandAppendTemplate; } else if (element.isOptional()) { if (sizedSegments) { formatString = optionalAppendAttrSizedOperandsTemplate; } else { formatString = optionalAppendOperandTemplate; } } else { assert(element.isVariadic() && "unhandled element group type"); // If emitting with sizedSegments, then we add the actual list-typed // element. Otherwise, we extend the actual operands. if (sizedSegments) { formatString = multiOperandAppendPackTemplate; } else { formatString = multiOperandAppendTemplate; } } builderLines.push_back(formatv(formatString.data(), name)); } } /// Python code template for deriving the operation result types from its /// attribute: /// - {0} is the name of the attribute from which to derive the types. constexpr const char *deriveTypeFromAttrTemplate = R"Py(_ods_result_type_source_attr = attributes["{0}"] _ods_derived_result_type = ( _ods_ir.TypeAttr(_ods_result_type_source_attr).value if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else _ods_result_type_source_attr.type))Py"; /// Python code template appending {0} type {1} times to the results list. constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; /// Appends the given multiline string as individual strings into /// `builderLines`. static void appendLineByLine(StringRef string, SmallVectorImpl &builderLines) { std::pair split = std::make_pair(string, string); do { split = split.second.split('\n'); builderLines.push_back(split.first.str()); } while (!split.second.empty()); } /// Populates `builderLines` with additional lines that are required in the /// builder to set up op results. 
static void populateBuilderLinesResult(const Operator &op, ArrayRef names, SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; if (hasSameArgumentAndResultTypes(op)) { builderLines.push_back(formatv(appendSameResultsTemplate, "operands[0].type", op.getNumResults())); return; } if (hasFirstAttrDerivedResultTypes(op)) { const NamedAttribute &firstAttr = op.getAttribute(0); assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " "from which the type is derived"); appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), builderLines); builderLines.push_back(formatv(appendSameResultsTemplate, "_ods_derived_result_type", op.getNumResults())); return; } if (hasInferTypeInterface(op)) return; // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { const NamedTypeConstraint &element = op.getResult(i); std::string name = names[i]; // Choose the formatting string based on the element kind. StringRef formatString; if (!element.isVariableLength()) { formatString = singleResultAppendTemplate; } else if (element.isOptional()) { formatString = optionalAppendResultTemplate; } else { assert(element.isVariadic() && "unhandled element group type"); // If emitting with sizedSegments, then we add the actual list-typed // element. Otherwise, we extend the actual operands. if (sizedSegments) { formatString = singleResultAppendTemplate; } else { formatString = multiResultAppendTemplate; } } builderLines.push_back(formatv(formatString.data(), name)); } } /// If the operation has variadic regions, adds a builder argument to specify /// the number of those regions and builder lines to forward it to the generic /// constructor. static void populateBuilderRegions(const Operator &op, SmallVectorImpl &builderArgs, SmallVectorImpl &builderLines) { if (op.hasNoVariadicRegions()) return; // This is currently enforced when Operator is constructed. 
assert(op.getNumVariadicRegions() == 1 && op.getRegion(op.getNumRegions() - 1).isVariadic() && "expected the last region to be varidic"); const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); std::string name = ("num_" + region.name.take_front().lower() + region.name.drop_front()) .str(); builderArgs.push_back(name); builderLines.push_back( formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); } /// Emits a default builder constructing an operation from the list of its /// result types, followed by a list of its operands. Returns vector /// of fully built functionArgs for downstream users (to save having to /// rebuild anew). static SmallVector emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { SmallVector builderArgs; SmallVector builderLines; SmallVector operandArgNames; SmallVector successorArgNames; builderArgs.reserve(op.getNumOperands() + op.getNumResults() + op.getNumNativeAttributes() + op.getNumSuccessors()); populateBuilderArgsResults(op, builderArgs); size_t numResultArgs = builderArgs.size(); populateBuilderArgs(op, builderArgs, operandArgNames); size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs), builderLines); populateBuilderLinesResult( op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderRegions(op, builderArgs, builderLines); // Layout of builderArgs vector elements: // [ result_args operand_attr_args successor_args regions ] // Determine whether the argument corresponding to a given index into the // builderArgs vector is a python keyword argument or not. auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { // All result, successor, and region arguments are positional arguments. 
// Arguments outside the ODS operand/attribute range (the leading result
// arguments and anything after the ODS arguments) are never keyword arguments.
    if ((builderArgIndex < numResultArgs) ||
        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
      return false;

    // Keyword arguments:
    //   - optional named attributes (including unit attributes)
    //   - default-valued named attributes
    //   - optional operands
    Argument a = op.getArg(builderArgIndex - numResultArgs);
    // NOTE(review): the explicit target types of both dyn_cast_if_present calls
    // appear to be missing in this copy; they cannot be deduced from `a`.
    if (auto *nattr = llvm::dyn_cast_if_present(a))
      return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
    if (auto *ntype = llvm::dyn_cast_if_present(a))
      return ntype->isOptional();
    return false;
  };

  // StringRefs in functionArgs refer to strings allocated by builderArgs.
  // NOTE(review): the element type of this SmallVector appears to be missing.
  SmallVector functionArgs;

  // Add positional arguments.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (!isKeywordArgFn(i))
      functionArgs.push_back(builderArgs[i]);
  }

  // Add a bare '*' to indicate that all following arguments must be keyword
  // arguments.
  functionArgs.push_back("*");

  // Add a default 'None' value to each keyword arg string, and then add to the
  // function args list. Each keyword entry is extended *before* a StringRef to
  // it is taken, and positional entries are never modified, so the append does
  // not disturb references already stored in functionArgs.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (isKeywordArgFn(i)) {
      builderArgs[i].append("=None");
      functionArgs.push_back(builderArgs[i]);
    }
  }
  functionArgs.push_back("loc=None");
  functionArgs.push_back("ip=None");

  // Keyword arguments forwarded to the OpView constructor call.
  SmallVector initArgs;
  initArgs.push_back("self.OPERATION_NAME");
  initArgs.push_back("self._ODS_REGIONS");
  initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
  initArgs.push_back("self._ODS_RESULT_SEGMENTS");
  initArgs.push_back("attributes=attributes");
  // `results=` is only forwarded when hasInferTypeInterface(op) is false.
  if (!hasInferTypeInterface(op))
    initArgs.push_back("results=results");
  initArgs.push_back("operands=operands");
  initArgs.push_back("successors=_ods_successors");
  initArgs.push_back("regions=regions");
  initArgs.push_back("loc=loc");
  initArgs.push_back("ip=ip");

  // NOTE(review): the builderLines separator's indentation presumably has to
  // match how `{1}` is indented inside initTemplate (not visible here) -
  // confirm.
  os << formatv(initTemplate, llvm::join(functionArgs, ", "),
                llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
  // Return owning copies: functionArgs only holds StringRefs (into builderArgs
  // and string literals).
  return llvm::to_vector<8>(
      llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
}

/// Emits the `_ODS_<kind>_SEGMENTS` class attribute. emitOpBindings calls this
/// only for ops carrying the trait returned by attrSizedTraitForKind for the
/// corresponding kind. `getNumElements` and `getElement` select which list
/// (operands or results) is described.
///
/// The emitted value is a Python list literal with one entry per element:
///   0 for an optional element, -1 for a variadic one, 1 otherwise.
/// For example, "[1,0,-1,]"; the trailing comma is valid Python.
///
/// NOTE(review): the template argument lists of both llvm::function_ref
/// parameters appear to be missing in this copy.
static void emitSegmentSpec(
    const Operator &op, const char *kind,
    llvm::function_ref getNumElements,
    llvm::function_ref getElement,
    raw_ostream &os) {
  std::string segmentSpec("[");
  for (int i = 0, e = getNumElements(op); i < e; ++i) {
    const NamedTypeConstraint &element = getElement(op, i);
    // Optional is tested before variadic, so an element reported as optional
    // always maps to 0.
    if (element.isOptional()) {
      segmentSpec.append("0,");
    } else if (element.isVariadic()) {
      segmentSpec.append("-1,");
    } else {
      segmentSpec.append("1,");
    }
  }
  segmentSpec.append("]");
  os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}

/// Emits the `_ODS_REGIONS` class attribute for `op`.
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
  // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
  // Note that the base OpView class defines this as (0, True).
  // Variadic regions are not counted towards the minimum.
  unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
  os << formatv(opClassRegionSpecTemplate, minRegionCount,
                op.hasNoVariadicRegions() ? "True" : "False");
}

/// Emits named accessors to regions.
///
/// Unnamed regions are skipped. Each accessor is formatted with the region's
/// sanitized name and its position. A variadic region gets the open-ended
/// slice "N:" instead; it is asserted to be the last region.
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
  for (const auto &en : llvm::enumerate(op.getRegions())) {
    // NOTE(review): `®ion` looks like a mis-decoded `&region` (the code below
    // uses `region`); as written this declaration will not compile.
    const NamedRegion ®ion = en.value();
    if (region.name.empty())
      continue;
    assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
           "expected only the last region to be variadic");
    os << formatv(regionAccessorTemplate, sanitizeName(region.name),
                  std::to_string(en.index()) +
                      (region.isVariadic() ? ":" : ""));
  }
}

/// Emits builder that extracts results from op
///
/// `functionArgs` is the argument list returned by the default builder: entries
/// of the form "name" or "name=default", plus a bare "*" separator. The
/// generated function is named after the op with its dialect prefix removed.
/// Its parameters are the snake_case forms of those names, keeping any
/// default. It forwards each argument as `name=snake_name`.
/// NOTE(review): the element type of the functionArgs SmallVector (and of
/// argMaybeDefault below) appears to be missing in this copy.
static void emitValueBuilder(const Operator &op, SmallVector functionArgs,
                             raw_ostream &os) {
  // Params with (possibly) default args.
  // NOTE(review): an entry containing more than one '=' splits into more than
  // two pieces and would lose its default here.
  auto valueBuilderParams =
      llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
        SmallVector argMaybeDefault =
            llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
        auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
        if (argMaybeDefault.size() == 2)
          return arg + "=" + argMaybeDefault[1].str();
        return arg;
      });
  // Actual args passed to op builder (e.g., opParam=op_param).
  // The bare '*' separator is not an argument and is filtered out. Defaults are
  // dropped by keeping only the text before the first '='.
  auto opBuilderArgs = llvm::map_range(
      llvm::make_filter_range(functionArgs,
                              [](const std::string &s) { return s != "*"; }),
      [](const std::string &arg) {
        auto lhs = *llvm::split(arg, "=").begin();
        return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
      });
  // Drop everything up to and including the first '.' (the dialect prefix).
  // If there is no '.', find() yields npos, npos + 1 wraps to 0, and the whole
  // name is kept.
  std::string nameWithoutDialect = sanitizeName(
      op.getOperationName().substr(op.getOperationName().find('.') + 1));
  std::string params = llvm::join(valueBuilderParams, ", ");
  std::string args = llvm::join(opBuilderArgs, ", ");
  // Python type string chosen from the result count (presumably used as the
  // generated function's return annotation).
  const char *type =
      (op.getNumResults() > 1
           ? "_Sequence[_ods_ir.Value]"
           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
  if (op.getNumVariableLengthResults() > 0) {
    os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
                  op.getCppClassName(), params, args, type);
  } else {
    // With no variable-length results, the result count fixes which suffix
    // ("", ".result" or ".results") valueBuilderTemplate receives.
    const char *results;
    if (op.getNumResults() == 0) {
      results = "";
    } else if (op.getNumResults() == 1) {
      results = ".result";
    } else {
      results = ".results";
    }
    os << formatv(valueBuilderTemplate, nameWithoutDialect,
                  op.getCppClassName(), params, args, type, results);
  }
}

/// Emits bindings for a specific Op to the given output stream.
///
/// Output order:
///   1. the class header;
///   2. operand/result segment specs, only when the op has the corresponding
///      attrSizedTraitForKind trait;
///   3. the region spec;
///   4. the default builder;
///   5. operand, attribute, result and region accessors;
///   6. the value builder, which reuses the argument list returned by
///      emitDefaultOpBuilder.
/// NOTE(review): the element type of the functionArgs SmallVector appears to
/// be missing in this copy.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
  os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());

  // Sized segments.
  if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
    emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
  }
  if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
    emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
  }

  emitRegionAttributes(op, os);
  SmallVector functionArgs = emitDefaultOpBuilder(op, os);
  emitOperandAccessors(op, os);
  emitAttributeAccessors(op, os);
  emitResultAccessors(op, os);
  emitRegionAccessors(op, os);
  emitValueBuilder(op, functionArgs, os);
}

/// Emits bindings for every op of the dialect specified on the command line.
/// The ops are preceded by the file header and either a new dialect class or,
/// for a dialect extension, an import of the existing one. Returns `false` on
/// success to comply with TableGen registration requirements.
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { if (clDialectName.empty()) llvm::PrintFatalError("dialect name not provided"); os << fileHeader; if (!clDialectExtensionName.empty()) os << formatv(dialectExtensionTemplate, clDialectName.getValue()); else os << formatv(dialectClassTemplate, clDialectName.getValue()); for (const Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec); if (op.getDialectName() == clDialectName.getValue()) emitOpBindings(op, os); } return false; } static GenRegistration genPythonBindings("gen-python-op-bindings", "Generate Python bindings for MLIR Ops", &emitAllOps);