//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "OpGenHelpers.h"

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;

/// File header and includes emitted at the top of every generated Python
/// module. It imports the `_ods_common` helpers under `_ods_`-prefixed (or
/// underscore-prefixed) aliases, binds `_ods_ir`, and imports `builtins` and
/// the `typing` names referenced by the other templates in this file.
/// NOTE(review): the text below contains no `{0}` placeholder, so no dialect
/// namespace is substituted into it.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template re-importing the `_Dialect` class from an already generated ops
/// module instead of declaring a new one:
///   {0} is spliced into the module name `._{0}_ops_gen` (presumably the
///   dialect name -- confirm at the call site).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or std::nullopt)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// The generated property indexes `self.operation.{1}s` directly, so it is
/// only valid when the position does not depend on a variable-length group.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// The generated code derives the length of the one variable-length group as
/// `len(elements) - {2} + 1` and shifts the static position {3} by it.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups. The generated property returns
/// `None` in that case.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
// The generated property returns the slice of `self.operation.{1}s` that
// starts at {3} and whose length is inferred from the total element count.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
/// The text deliberately does not end with a newline: one of the two "second
/// part" templates below is emitted right after it and supplies the `return`.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + elements_per_group]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
// The segment sizes are read from the `{1}SegmentSizes` attribute of the
// operation, i.e. `operandSegmentSizes` or `resultSegmentSizes`.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// The suffix is appended to `{0}_range` and yields its first element, or
/// `None` when the range is empty.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter, returns None when the
/// attribute is not present on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
// The generated setter raises ValueError when given None.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Setting None when the attribute is already absent is a no-op.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Any value for which `bool(value)` is true stores `_ods_ir.UnitAttr.get()`.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
260 constexpr const char *attributeDeleterTemplate = R"Py( 261 @{0}.deleter 262 def {0}(self): 263 del self.operation.attributes["{1}"] 264 )Py"; 265 266 constexpr const char *regionAccessorTemplate = R"Py( 267 @builtins.property 268 def {0}(self): 269 return self.regions[{1}] 270 )Py"; 271 272 constexpr const char *valueBuilderTemplate = R"Py( 273 def {0}({2}) -> {4}: 274 return {1}({3}){5} 275 )Py"; 276 277 constexpr const char *valueBuilderVariadicTemplate = R"Py( 278 def {0}({2}) -> {4}: 279 return _get_op_result_or_op_results({1}({3})) 280 )Py"; 281 282 static llvm::cl::OptionCategory 283 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 284 285 static llvm::cl::opt<std::string> 286 clDialectName("bind-dialect", 287 llvm::cl::desc("The dialect to run the generator for"), 288 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 289 290 static llvm::cl::opt<std::string> clDialectExtensionName( 291 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), 292 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 293 294 using AttributeClasses = DenseMap<StringRef, StringRef>; 295 296 /// Checks whether `str` would shadow a generated variable or attribute 297 /// part of the OpView API. 298 static bool isODSReserved(StringRef str) { 299 static llvm::StringSet<> reserved( 300 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 301 "loc", "verify", "regions", "results", "self", "operation", 302 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 303 return str.starts_with("_ods_") || str.ends_with("_ods") || 304 reserved.contains(str); 305 } 306 307 /// Modifies the `name` in a way that it becomes suitable for Python bindings 308 /// (does not change the `name` if it already is suitable) and returns the 309 /// modified version. 
310 static std::string sanitizeName(StringRef name) { 311 std::string processedStr = name.str(); 312 std::replace_if( 313 processedStr.begin(), processedStr.end(), 314 [](char c) { return !llvm::isAlnum(c); }, '_'); 315 316 if (llvm::isDigit(*processedStr.begin())) 317 return "_" + processedStr; 318 319 if (isPythonReserved(processedStr) || isODSReserved(processedStr)) 320 return processedStr + "_"; 321 return processedStr; 322 } 323 324 static std::string attrSizedTraitForKind(const char *kind) { 325 return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 326 StringRef(kind).take_front().upper(), 327 StringRef(kind).drop_front()); 328 } 329 330 /// Emits accessors to "elements" of an Op definition. Currently, the supported 331 /// elements are operands and results, indicated by `kind`, which must be either 332 /// `operand` or `result` and is used verbatim in the emitted code. 333 static void emitElementAccessors( 334 const Operator &op, raw_ostream &os, const char *kind, 335 unsigned numVariadicGroups, unsigned numElements, 336 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 337 getElement) { 338 assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"}, 339 kind) && 340 "unsupported kind"); 341 342 // Traits indicating how to process variadic elements. 343 std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 344 StringRef(kind).take_front().upper(), 345 StringRef(kind).drop_front()); 346 std::string attrSizedTrait = attrSizedTraitForKind(kind); 347 348 // If there is only one variable-length element group, its size can be 349 // inferred from the total number of elements. If there are none, the 350 // generation is straightforward. 
351 if (numVariadicGroups <= 1) { 352 bool seenVariableLength = false; 353 for (unsigned i = 0; i < numElements; ++i) { 354 const NamedTypeConstraint &element = getElement(op, i); 355 if (element.isVariableLength()) 356 seenVariableLength = true; 357 if (element.name.empty()) 358 continue; 359 if (element.isVariableLength()) { 360 os << formatv(element.isOptional() ? opOneOptionalTemplate 361 : opOneVariadicTemplate, 362 sanitizeName(element.name), kind, numElements, i); 363 } else if (seenVariableLength) { 364 os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name), 365 kind, numElements, i); 366 } else { 367 os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i); 368 } 369 } 370 return; 371 } 372 373 // Handle the operations where variadic groups have the same size. 374 if (op.getTrait(sameSizeTrait)) { 375 // Count the number of simple elements 376 unsigned numSimpleLength = 0; 377 for (unsigned i = 0; i < numElements; ++i) { 378 const NamedTypeConstraint &element = getElement(op, i); 379 if (!element.isVariableLength()) { 380 ++numSimpleLength; 381 } 382 } 383 384 // Generate the accessors 385 int numPrecedingSimple = 0; 386 int numPrecedingVariadic = 0; 387 for (unsigned i = 0; i < numElements; ++i) { 388 const NamedTypeConstraint &element = getElement(op, i); 389 if (!element.name.empty()) { 390 os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), 391 kind, numSimpleLength, numVariadicGroups, 392 numPrecedingSimple, numPrecedingVariadic); 393 os << formatv(element.isVariableLength() 394 ? opVariadicEqualVariadicTemplate 395 : opVariadicEqualSimpleTemplate, 396 kind); 397 } 398 if (element.isVariableLength()) 399 ++numPrecedingVariadic; 400 else 401 ++numPrecedingSimple; 402 } 403 return; 404 } 405 406 // Handle the operations where the size of groups (variadic or not) is 407 // provided as an attribute. For non-variadic elements, make sure to return 408 // an element rather than a singleton container. 
409 if (op.getTrait(attrSizedTrait)) { 410 for (unsigned i = 0; i < numElements; ++i) { 411 const NamedTypeConstraint &element = getElement(op, i); 412 if (element.name.empty()) 413 continue; 414 std::string trailing; 415 if (!element.isVariableLength()) 416 trailing = "[0]"; 417 else if (element.isOptional()) 418 trailing = std::string( 419 formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 420 os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind, 421 i, trailing); 422 } 423 return; 424 } 425 426 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 427 } 428 429 /// Free function helpers accessing Operator components. 430 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 431 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 432 return op.getOperand(i); 433 } 434 static int getNumResults(const Operator &op) { return op.getNumResults(); } 435 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 436 return op.getResult(i); 437 } 438 439 /// Emits accessors to Op operands. 440 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 441 emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(), 442 getNumOperands(op), getOperand); 443 } 444 445 /// Emits accessors Op results. 446 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 447 emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(), 448 getNumResults(op), getResult); 449 } 450 451 /// Emits accessors to Op attributes. 452 static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { 453 for (const auto &namedAttr : op.getAttributes()) { 454 // Skip "derived" attributes because they are just C++ functions that we 455 // don't currently expose. 
456 if (namedAttr.attr.isDerivedAttr()) 457 continue; 458 459 if (namedAttr.name.empty()) 460 continue; 461 462 std::string sanitizedName = sanitizeName(namedAttr.name); 463 464 // Unit attributes are handled specially. 465 if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { 466 os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name); 467 os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name); 468 os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); 469 continue; 470 } 471 472 if (namedAttr.attr.isOptional()) { 473 os << formatv(optionalAttributeGetterTemplate, sanitizedName, 474 namedAttr.name); 475 os << formatv(optionalAttributeSetterTemplate, sanitizedName, 476 namedAttr.name); 477 os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); 478 } else { 479 os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name); 480 os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name); 481 // Non-optional attributes cannot be deleted. 482 } 483 } 484 } 485 486 /// Template for the default auto-generated builder. 487 /// {0} is a comma-separated list of builder arguments, including the trailing 488 /// `loc` and `ip`; 489 /// {1} is the code populating `operands`, `results` and `attributes`, 490 /// `successors` fields. 491 constexpr const char *initTemplate = R"Py( 492 def __init__(self, {0}): 493 operands = [] 494 results = [] 495 attributes = {{} 496 regions = None 497 {1} 498 super().__init__({2}) 499 )Py"; 500 501 /// Template for appending a single element to the operand/result list. 502 /// {0} is the field name. 503 constexpr const char *singleOperandAppendTemplate = "operands.append({0})"; 504 constexpr const char *singleResultAppendTemplate = "results.append({0})"; 505 506 /// Template for appending an optional element to the operand/result list. 507 /// {0} is the field name. 
// Skips the element entirely when it is None.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append({0})";
// Attr-sized variant: appends the element unconditionally, None included.
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the whole list as one entry instead of
/// extending the operand list with its elements.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the key under which the attribute is stored in `attributes`;
///   {2} is the name used to look up the builder in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
///   {0} is the builder argument name;
///   {1} is the key under which the attribute is stored in `attributes`;
///   {2} is the name used to look up the builder in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
542 constexpr const char *initOptionalAttributeWithBuilderTemplate = 543 R"Py(if {0} is not None: attributes["{1}"] = ({0} if ( 544 isinstance({0}, _ods_ir.Attribute) or 545 not _ods_ir.AttrBuilder.contains('{2}')) else 546 _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py"; 547 548 constexpr const char *initUnitAttributeTemplate = 549 R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( 550 _ods_get_default_loc_context(loc)))Py"; 551 552 /// Template to initialize the successors list in the builder if there are any 553 /// successors. 554 /// {0} is the value to initialize the successors list to. 555 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; 556 557 /// Template to append or extend the list of successors in the builder. 558 /// {0} is the list method ('append' or 'extend'); 559 /// {1} is the value to add. 560 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; 561 562 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer 563 /// result types of the given operation. 564 static bool hasSameArgumentAndResultTypes(const Operator &op) { 565 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 566 op.getNumVariableLengthResults() == 0; 567 } 568 569 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 570 /// result types of the given operation. 571 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 572 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 573 op.getNumVariableLengthResults() == 0; 574 } 575 576 /// Returns true if the InferTypeOpInterface can be used to infer result types 577 /// of the given operation. 
578 static bool hasInferTypeInterface(const Operator &op) { 579 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 580 op.getNumRegions() == 0; 581 } 582 583 /// Returns true if there is a trait or interface that can be used to infer 584 /// result types of the given operation. 585 static bool canInferType(const Operator &op) { 586 return hasSameArgumentAndResultTypes(op) || 587 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 588 } 589 590 /// Populates `builderArgs` with result names if the builder is expected to 591 /// accept them as arguments. 592 static void 593 populateBuilderArgsResults(const Operator &op, 594 SmallVectorImpl<std::string> &builderArgs) { 595 if (canInferType(op)) 596 return; 597 598 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 599 std::string name = op.getResultName(i).str(); 600 if (name.empty()) { 601 if (op.getNumResults() == 1) { 602 // Special case for one result, make the default name be 'result' 603 // to properly match the built-in result accessor. 604 name = "result"; 605 } else { 606 name = formatv("_gen_res_{0}", i); 607 } 608 } 609 name = sanitizeName(name); 610 builderArgs.push_back(name); 611 } 612 } 613 614 /// Populates `builderArgs` with the Python-compatible names of builder function 615 /// arguments using intermixed attributes and operands in the same order as they 616 /// appear in the `arguments` field of the op definition. Additionally, 617 /// `operandNames` is populated with names of operands in their order of 618 /// appearance. 
619 static void populateBuilderArgs(const Operator &op, 620 SmallVectorImpl<std::string> &builderArgs, 621 SmallVectorImpl<std::string> &operandNames) { 622 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 623 std::string name = op.getArgName(i).str(); 624 if (name.empty()) 625 name = formatv("_gen_arg_{0}", i); 626 name = sanitizeName(name); 627 builderArgs.push_back(name); 628 if (!isa<NamedAttribute *>(op.getArg(i))) 629 operandNames.push_back(name); 630 } 631 } 632 633 /// Populates `builderArgs` with the Python-compatible names of builder function 634 /// successor arguments. Additionally, `successorArgNames` is also populated. 635 static void 636 populateBuilderArgsSuccessors(const Operator &op, 637 SmallVectorImpl<std::string> &builderArgs, 638 SmallVectorImpl<std::string> &successorArgNames) { 639 640 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 641 NamedSuccessor successor = op.getSuccessor(i); 642 std::string name = std::string(successor.name); 643 if (name.empty()) 644 name = formatv("_gen_successor_{0}", i); 645 name = sanitizeName(name); 646 builderArgs.push_back(name); 647 successorArgNames.push_back(name); 648 } 649 } 650 651 /// Populates `builderLines` with additional lines that are required in the 652 /// builder to set up operation attributes. `argNames` is expected to contain 653 /// the names of builder arguments that correspond to op arguments, i.e. to the 654 /// operands and attributes in the same order as they appear in the `arguments` 655 /// field. 656 static void 657 populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames, 658 SmallVectorImpl<std::string> &builderLines) { 659 builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)"); 660 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 661 Argument arg = op.getArg(i); 662 auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg); 663 if (!attribute) 664 continue; 665 666 // Unit attributes are handled specially. 
667 if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { 668 builderLines.push_back( 669 formatv(initUnitAttributeTemplate, attribute->name, argNames[i])); 670 continue; 671 } 672 673 builderLines.push_back(formatv( 674 attribute->attr.isOptional() || attribute->attr.hasDefaultValue() 675 ? initOptionalAttributeWithBuilderTemplate 676 : initAttributeWithBuilderTemplate, 677 argNames[i], attribute->name, attribute->attr.getAttrDefName())); 678 } 679 } 680 681 /// Populates `builderLines` with additional lines that are required in the 682 /// builder to set up successors. successorArgNames is expected to correspond 683 /// to the Python argument name for each successor on the op. 684 static void 685 populateBuilderLinesSuccessors(const Operator &op, 686 ArrayRef<std::string> successorArgNames, 687 SmallVectorImpl<std::string> &builderLines) { 688 if (successorArgNames.empty()) { 689 builderLines.push_back(formatv(initSuccessorsTemplate, "None")); 690 return; 691 } 692 693 builderLines.push_back(formatv(initSuccessorsTemplate, "[]")); 694 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 695 auto &argName = successorArgNames[i]; 696 const NamedSuccessor &successor = op.getSuccessor(i); 697 builderLines.push_back(formatv(addSuccessorTemplate, 698 successor.isVariadic() ? "extend" : "append", 699 argName)); 700 } 701 } 702 703 /// Populates `builderLines` with additional lines that are required in the 704 /// builder to set up op operands. 705 static void 706 populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names, 707 SmallVectorImpl<std::string> &builderLines) { 708 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 709 710 // For each element, find or generate a name. 711 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 712 const NamedTypeConstraint &element = op.getOperand(i); 713 std::string name = names[i]; 714 715 // Choose the formatting string based on the element kind. 
716 StringRef formatString; 717 if (!element.isVariableLength()) { 718 formatString = singleOperandAppendTemplate; 719 } else if (element.isOptional()) { 720 if (sizedSegments) { 721 formatString = optionalAppendAttrSizedOperandsTemplate; 722 } else { 723 formatString = optionalAppendOperandTemplate; 724 } 725 } else { 726 assert(element.isVariadic() && "unhandled element group type"); 727 // If emitting with sizedSegments, then we add the actual list-typed 728 // element. Otherwise, we extend the actual operands. 729 if (sizedSegments) { 730 formatString = multiOperandAppendPackTemplate; 731 } else { 732 formatString = multiOperandAppendTemplate; 733 } 734 } 735 736 builderLines.push_back(formatv(formatString.data(), name)); 737 } 738 } 739 740 /// Python code template for deriving the operation result types from its 741 /// attribute: 742 /// - {0} is the name of the attribute from which to derive the types. 743 constexpr const char *deriveTypeFromAttrTemplate = 744 R"Py(_ods_result_type_source_attr = attributes["{0}"] 745 _ods_derived_result_type = ( 746 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 747 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 748 _ods_result_type_source_attr.type))Py"; 749 750 /// Python code template appending {0} type {1} times to the results list. 751 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 752 753 /// Appends the given multiline string as individual strings into 754 /// `builderLines`. 755 static void appendLineByLine(StringRef string, 756 SmallVectorImpl<std::string> &builderLines) { 757 758 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 759 do { 760 split = split.second.split('\n'); 761 builderLines.push_back(split.first.str()); 762 } while (!split.second.empty()); 763 } 764 765 /// Populates `builderLines` with additional lines that are required in the 766 /// builder to set up op results. 
767 static void 768 populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names, 769 SmallVectorImpl<std::string> &builderLines) { 770 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 771 772 if (hasSameArgumentAndResultTypes(op)) { 773 builderLines.push_back(formatv(appendSameResultsTemplate, 774 "operands[0].type", op.getNumResults())); 775 return; 776 } 777 778 if (hasFirstAttrDerivedResultTypes(op)) { 779 const NamedAttribute &firstAttr = op.getAttribute(0); 780 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 781 "from which the type is derived"); 782 appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 783 builderLines); 784 builderLines.push_back(formatv(appendSameResultsTemplate, 785 "_ods_derived_result_type", 786 op.getNumResults())); 787 return; 788 } 789 790 if (hasInferTypeInterface(op)) 791 return; 792 793 // For each element, find or generate a name. 794 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 795 const NamedTypeConstraint &element = op.getResult(i); 796 std::string name = names[i]; 797 798 // Choose the formatting string based on the element kind. 799 StringRef formatString; 800 if (!element.isVariableLength()) { 801 formatString = singleResultAppendTemplate; 802 } else if (element.isOptional()) { 803 formatString = optionalAppendResultTemplate; 804 } else { 805 assert(element.isVariadic() && "unhandled element group type"); 806 // If emitting with sizedSegments, then we add the actual list-typed 807 // element. Otherwise, we extend the actual operands. 
808 if (sizedSegments) { 809 formatString = singleResultAppendTemplate; 810 } else { 811 formatString = multiResultAppendTemplate; 812 } 813 } 814 815 builderLines.push_back(formatv(formatString.data(), name)); 816 } 817 } 818 819 /// If the operation has variadic regions, adds a builder argument to specify 820 /// the number of those regions and builder lines to forward it to the generic 821 /// constructor. 822 static void populateBuilderRegions(const Operator &op, 823 SmallVectorImpl<std::string> &builderArgs, 824 SmallVectorImpl<std::string> &builderLines) { 825 if (op.hasNoVariadicRegions()) 826 return; 827 828 // This is currently enforced when Operator is constructed. 829 assert(op.getNumVariadicRegions() == 1 && 830 op.getRegion(op.getNumRegions() - 1).isVariadic() && 831 "expected the last region to be varidic"); 832 833 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 834 std::string name = 835 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 836 .str(); 837 builderArgs.push_back(name); 838 builderLines.push_back( 839 formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 840 } 841 842 /// Emits a default builder constructing an operation from the list of its 843 /// result types, followed by a list of its operands. Returns vector 844 /// of fully built functionArgs for downstream users (to save having to 845 /// rebuild anew). 
846 static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, 847 raw_ostream &os) { 848 SmallVector<std::string> builderArgs; 849 SmallVector<std::string> builderLines; 850 SmallVector<std::string> operandArgNames; 851 SmallVector<std::string> successorArgNames; 852 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 853 op.getNumNativeAttributes() + op.getNumSuccessors()); 854 populateBuilderArgsResults(op, builderArgs); 855 size_t numResultArgs = builderArgs.size(); 856 populateBuilderArgs(op, builderArgs, operandArgNames); 857 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; 858 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); 859 860 populateBuilderLinesOperand(op, operandArgNames, builderLines); 861 populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs), 862 builderLines); 863 populateBuilderLinesResult( 864 op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines); 865 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 866 populateBuilderRegions(op, builderArgs, builderLines); 867 868 // Layout of builderArgs vector elements: 869 // [ result_args operand_attr_args successor_args regions ] 870 871 // Determine whether the argument corresponding to a given index into the 872 // builderArgs vector is a python keyword argument or not. 873 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { 874 // All result, successor, and region arguments are positional arguments. 
875 if ((builderArgIndex < numResultArgs) || 876 (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) 877 return false; 878 // Keyword arguments: 879 // - optional named attributes (including unit attributes) 880 // - default-valued named attributes 881 // - optional operands 882 Argument a = op.getArg(builderArgIndex - numResultArgs); 883 if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a)) 884 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); 885 if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a)) 886 return ntype->isOptional(); 887 return false; 888 }; 889 890 // StringRefs in functionArgs refer to strings allocated by builderArgs. 891 SmallVector<StringRef> functionArgs; 892 893 // Add positional arguments. 894 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 895 if (!isKeywordArgFn(i)) 896 functionArgs.push_back(builderArgs[i]); 897 } 898 899 // Add a bare '*' to indicate that all following arguments must be keyword 900 // arguments. 901 functionArgs.push_back("*"); 902 903 // Add a default 'None' value to each keyword arg string, and then add to the 904 // function args list. 
905 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 906 if (isKeywordArgFn(i)) { 907 builderArgs[i].append("=None"); 908 functionArgs.push_back(builderArgs[i]); 909 } 910 } 911 functionArgs.push_back("loc=None"); 912 functionArgs.push_back("ip=None"); 913 914 SmallVector<std::string> initArgs; 915 initArgs.push_back("self.OPERATION_NAME"); 916 initArgs.push_back("self._ODS_REGIONS"); 917 initArgs.push_back("self._ODS_OPERAND_SEGMENTS"); 918 initArgs.push_back("self._ODS_RESULT_SEGMENTS"); 919 initArgs.push_back("attributes=attributes"); 920 if (!hasInferTypeInterface(op)) 921 initArgs.push_back("results=results"); 922 initArgs.push_back("operands=operands"); 923 initArgs.push_back("successors=_ods_successors"); 924 initArgs.push_back("regions=regions"); 925 initArgs.push_back("loc=loc"); 926 initArgs.push_back("ip=ip"); 927 928 os << formatv(initTemplate, llvm::join(functionArgs, ", "), 929 llvm::join(builderLines, "\n "), llvm::join(initArgs, ", ")); 930 return llvm::to_vector<8>( 931 llvm::map_range(functionArgs, [](StringRef s) { return s.str(); })); 932 } 933 934 static void emitSegmentSpec( 935 const Operator &op, const char *kind, 936 llvm::function_ref<int(const Operator &)> getNumElements, 937 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 938 getElement, 939 raw_ostream &os) { 940 std::string segmentSpec("["); 941 for (int i = 0, e = getNumElements(op); i < e; ++i) { 942 const NamedTypeConstraint &element = getElement(op, i); 943 if (element.isOptional()) { 944 segmentSpec.append("0,"); 945 } else if (element.isVariadic()) { 946 segmentSpec.append("-1,"); 947 } else { 948 segmentSpec.append("1,"); 949 } 950 } 951 segmentSpec.append("]"); 952 953 os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 954 } 955 956 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 957 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 
958 // Note that the base OpView class defines this as (0, True). 959 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 960 os << formatv(opClassRegionSpecTemplate, minRegionCount, 961 op.hasNoVariadicRegions() ? "True" : "False"); 962 } 963 964 /// Emits named accessors to regions. 965 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 966 for (const auto &en : llvm::enumerate(op.getRegions())) { 967 const NamedRegion ®ion = en.value(); 968 if (region.name.empty()) 969 continue; 970 971 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 972 "expected only the last region to be variadic"); 973 os << formatv(regionAccessorTemplate, sanitizeName(region.name), 974 std::to_string(en.index()) + 975 (region.isVariadic() ? ":" : "")); 976 } 977 } 978 979 /// Emits builder that extracts results from op 980 static void emitValueBuilder(const Operator &op, 981 SmallVector<std::string> functionArgs, 982 raw_ostream &os) { 983 // Params with (possibly) default args. 984 auto valueBuilderParams = 985 llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { 986 SmallVector<StringRef> argMaybeDefault = 987 llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "=")); 988 auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]); 989 if (argMaybeDefault.size() == 2) 990 return arg + "=" + argMaybeDefault[1].str(); 991 return arg; 992 }); 993 // Actual args passed to op builder (e.g., opParam=op_param). 
994 auto opBuilderArgs = llvm::map_range( 995 llvm::make_filter_range(functionArgs, 996 [](const std::string &s) { return s != "*"; }), 997 [](const std::string &arg) { 998 auto lhs = *llvm::split(arg, "=").begin(); 999 return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); 1000 }); 1001 std::string nameWithoutDialect = sanitizeName( 1002 op.getOperationName().substr(op.getOperationName().find('.') + 1)); 1003 std::string params = llvm::join(valueBuilderParams, ", "); 1004 std::string args = llvm::join(opBuilderArgs, ", "); 1005 const char *type = 1006 (op.getNumResults() > 1 1007 ? "_Sequence[_ods_ir.Value]" 1008 : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")); 1009 if (op.getNumVariableLengthResults() > 0) { 1010 os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect, 1011 op.getCppClassName(), params, args, type); 1012 } else { 1013 const char *results; 1014 if (op.getNumResults() == 0) { 1015 results = ""; 1016 } else if (op.getNumResults() == 1) { 1017 results = ".result"; 1018 } else { 1019 results = ".results"; 1020 } 1021 os << formatv(valueBuilderTemplate, nameWithoutDialect, 1022 op.getCppClassName(), params, args, type, results); 1023 } 1024 } 1025 1026 /// Emits bindings for a specific Op to the given output stream. 1027 static void emitOpBindings(const Operator &op, raw_ostream &os) { 1028 os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); 1029 1030 // Sized segments. 
1031 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 1032 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 1033 } 1034 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 1035 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 1036 } 1037 1038 emitRegionAttributes(op, os); 1039 SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); 1040 emitOperandAccessors(op, os); 1041 emitAttributeAccessors(op, os); 1042 emitResultAccessors(op, os); 1043 emitRegionAccessors(op, os); 1044 emitValueBuilder(op, functionArgs, os); 1045 } 1046 1047 /// Emits bindings for the dialect specified in the command line, including file 1048 /// headers and utilities. Returns `false` on success to comply with Tablegen 1049 /// registration requirements. 1050 static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { 1051 if (clDialectName.empty()) 1052 llvm::PrintFatalError("dialect name not provided"); 1053 1054 os << fileHeader; 1055 if (!clDialectExtensionName.empty()) 1056 os << formatv(dialectExtensionTemplate, clDialectName.getValue()); 1057 else 1058 os << formatv(dialectClassTemplate, clDialectName.getValue()); 1059 1060 for (const Record *rec : records.getAllDerivedDefinitions("Op")) { 1061 Operator op(rec); 1062 if (op.getDialectName() == clDialectName.getValue()) 1063 emitOpBindings(op, os); 1064 } 1065 return false; 1066 } 1067 1068 static GenRegistration 1069 genPythonBindings("gen-python-op-bindings", 1070 "Generate Python bindings for MLIR Ops", &emitAllOps); 1071