xref: /llvm-project/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (revision acde3f722ff3766f6f793884108d342b78623fe4)
1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpGenHelpers.h"
15 
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "llvm/ADT/StringSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 using namespace mlir;
25 using namespace mlir::tblgen;
26 using llvm::formatv;
27 using llvm::Record;
28 using llvm::RecordKeeper;
29 
30 /// File header and includes.
31 ///   {0} is the dialect namespace.
32 constexpr const char *fileHeader = R"Py(
33 # Autogenerated by mlir-tblgen; don't manually edit.
34 
35 from ._ods_common import _cext as _ods_cext
36 from ._ods_common import (
37     equally_sized_accessor as _ods_equally_sized_accessor,
38     get_default_loc_context as _ods_get_default_loc_context,
39     get_op_result_or_op_results as _get_op_result_or_op_results,
40     get_op_results_or_values as _get_op_results_or_values,
41     segmented_accessor as _ods_segmented_accessor,
42 )
43 _ods_ir = _ods_cext.ir
44 
45 import builtins
46 from typing import Sequence as _Sequence, Union as _Union
47 
48 )Py";
49 
50 /// Template for dialect class:
51 ///   {0} is the dialect namespace.
52 constexpr const char *dialectClassTemplate = R"Py(
53 @_ods_cext.register_dialect
54 class _Dialect(_ods_ir.Dialect):
55   DIALECT_NAMESPACE = "{0}"
56 )Py";
57 
58 constexpr const char *dialectExtensionTemplate = R"Py(
59 from ._{0}_ops_gen import _Dialect
60 )Py";
61 
62 /// Template for operation class:
63 ///   {0} is the Python class name;
64 ///   {1} is the operation name.
65 constexpr const char *opClassTemplate = R"Py(
66 @_ods_cext.register_operation(_Dialect)
67 class {0}(_ods_ir.OpView):
68   OPERATION_NAME = "{1}"
69 )Py";
70 
71 /// Template for class level declarations of operand and result
72 /// segment specs.
73 ///   {0} is either "OPERAND" or "RESULT"
74 ///   {1} is the segment spec
75 /// Each segment spec is either None (default) or an array of integers
76 /// where:
77 ///   1 = single element (expect non sequence operand/result)
78 ///   0 = optional element (expect a value or std::nullopt)
79 ///   -1 = operand/result is a sequence corresponding to a variadic
80 constexpr const char *opClassSizedSegmentsTemplate = R"Py(
81   _ODS_{0}_SEGMENTS = {1}
82 )Py";
83 
84 /// Template for class level declarations of the _ODS_REGIONS spec:
85 ///   {0} is the minimum number of regions
86 ///   {1} is the Python bool literal for hasNoVariadicRegions
87 constexpr const char *opClassRegionSpecTemplate = R"Py(
88   _ODS_REGIONS = ({0}, {1})
89 )Py";
90 
91 /// Template for single-element accessor:
92 ///   {0} is the name of the accessor;
93 ///   {1} is either 'operand' or 'result';
94 ///   {2} is the position in the element list.
95 constexpr const char *opSingleTemplate = R"Py(
96   @builtins.property
97   def {0}(self):
98     return self.operation.{1}s[{2}]
99 )Py";
100 
101 /// Template for single-element accessor after a variable-length group:
102 ///   {0} is the name of the accessor;
103 ///   {1} is either 'operand' or 'result';
104 ///   {2} is the total number of element groups;
105 ///   {3} is the position of the current group in the group list.
106 /// This works for both a single variadic group (non-negative length) and an
107 /// single optional element (zero length if the element is absent).
108 constexpr const char *opSingleAfterVariableTemplate = R"Py(
109   @builtins.property
110   def {0}(self):
111     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
112     return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
113 )Py";
114 
115 /// Template for an optional element accessor:
116 ///   {0} is the name of the accessor;
117 ///   {1} is either 'operand' or 'result';
118 ///   {2} is the total number of element groups;
119 ///   {3} is the position of the current group in the group list.
120 /// This works if we have only one variable-length group (and it's the optional
121 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
122 /// smaller than the total number of groups.
123 constexpr const char *opOneOptionalTemplate = R"Py(
124   @builtins.property
125   def {0}(self):
126     return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
127 )Py";
128 
129 /// Template for the variadic group accessor in the single variadic group case:
130 ///   {0} is the name of the accessor;
131 ///   {1} is either 'operand' or 'result';
132 ///   {2} is the total number of element groups;
133 ///   {3} is the position of the current group in the group list.
134 constexpr const char *opOneVariadicTemplate = R"Py(
135   @builtins.property
136   def {0}(self):
137     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
138     return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
139 )Py";
140 
141 /// First part of the template for equally-sized variadic group accessor:
142 ///   {0} is the name of the accessor;
143 ///   {1} is either 'operand' or 'result';
144 ///   {2} is the total number of non-variadic groups;
145 ///   {3} is the total number of variadic groups;
146 ///   {4} is the number of non-variadic groups preceding the current group;
147 ///   {5} is the number of variadic groups preceding the current group.
148 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
149   @builtins.property
150   def {0}(self):
151     start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
152 
153 /// Second part of the template for equally-sized case, accessing a single
154 /// element:
155 ///   {0} is either 'operand' or 'result'.
156 constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
157     return self.operation.{0}s[start]
158 )Py";
159 
160 /// Second part of the template for equally-sized case, accessing a variadic
161 /// group:
162 ///   {0} is either 'operand' or 'result'.
163 constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
164     return self.operation.{0}s[start:start + elements_per_group]
165 )Py";
166 
167 /// Template for an attribute-sized group accessor:
168 ///   {0} is the name of the accessor;
169 ///   {1} is either 'operand' or 'result';
170 ///   {2} is the position of the group in the group list;
171 ///   {3} is a return suffix (expected [0] for single-element, empty for
172 ///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
173 constexpr const char *opVariadicSegmentTemplate = R"Py(
174   @builtins.property
175   def {0}(self):
176     {1}_range = _ods_segmented_accessor(
177          self.operation.{1}s,
178          self.operation.attributes["{1}SegmentSizes"], {2})
179     return {1}_range{3}
180 )Py";
181 
182 /// Template for a suffix when accessing an optional element in the
183 /// attribute-sized case:
184 ///   {0} is either 'operand' or 'result';
185 constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
186     R"Py([0] if len({0}_range) > 0 else None)Py";
187 
188 /// Template for an operation attribute getter:
189 ///   {0} is the name of the attribute sanitized for Python;
190 ///   {1} is the original name of the attribute.
191 constexpr const char *attributeGetterTemplate = R"Py(
192   @builtins.property
193   def {0}(self):
194     return self.operation.attributes["{1}"]
195 )Py";
196 
197 /// Template for an optional operation attribute getter:
198 ///   {0} is the name of the attribute sanitized for Python;
199 ///   {1} is the original name of the attribute.
200 constexpr const char *optionalAttributeGetterTemplate = R"Py(
201   @builtins.property
202   def {0}(self):
203     if "{1}" not in self.operation.attributes:
204       return None
205     return self.operation.attributes["{1}"]
206 )Py";
207 
208 /// Template for a getter of a unit operation attribute, returns True of the
209 /// unit attribute is present, False otherwise (unit attributes have meaning
210 /// by mere presence):
211 ///    {0} is the name of the attribute sanitized for Python,
212 ///    {1} is the original name of the attribute.
213 constexpr const char *unitAttributeGetterTemplate = R"Py(
214   @builtins.property
215   def {0}(self):
216     return "{1}" in self.operation.attributes
217 )Py";
218 
219 /// Template for an operation attribute setter:
220 ///    {0} is the name of the attribute sanitized for Python;
221 ///    {1} is the original name of the attribute.
222 constexpr const char *attributeSetterTemplate = R"Py(
223   @{0}.setter
224   def {0}(self, value):
225     if value is None:
226       raise ValueError("'None' not allowed as value for mandatory attributes")
227     self.operation.attributes["{1}"] = value
228 )Py";
229 
230 /// Template for a setter of an optional operation attribute, setting to None
231 /// removes the attribute:
232 ///    {0} is the name of the attribute sanitized for Python;
233 ///    {1} is the original name of the attribute.
234 constexpr const char *optionalAttributeSetterTemplate = R"Py(
235   @{0}.setter
236   def {0}(self, value):
237     if value is not None:
238       self.operation.attributes["{1}"] = value
239     elif "{1}" in self.operation.attributes:
240       del self.operation.attributes["{1}"]
241 )Py";
242 
243 /// Template for a setter of a unit operation attribute, setting to None or
244 /// False removes the attribute:
245 ///    {0} is the name of the attribute sanitized for Python;
246 ///    {1} is the original name of the attribute.
247 constexpr const char *unitAttributeSetterTemplate = R"Py(
248   @{0}.setter
249   def {0}(self, value):
250     if bool(value):
251       self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
252     elif "{1}" in self.operation.attributes:
253       del self.operation.attributes["{1}"]
254 )Py";
255 
256 /// Template for a deleter of an optional or a unit operation attribute, removes
257 /// the attribute from the operation:
258 ///    {0} is the name of the attribute sanitized for Python;
259 ///    {1} is the original name of the attribute.
260 constexpr const char *attributeDeleterTemplate = R"Py(
261   @{0}.deleter
262   def {0}(self):
263     del self.operation.attributes["{1}"]
264 )Py";
265 
266 constexpr const char *regionAccessorTemplate = R"Py(
267   @builtins.property
268   def {0}(self):
269     return self.regions[{1}]
270 )Py";
271 
272 constexpr const char *valueBuilderTemplate = R"Py(
273 def {0}({2}) -> {4}:
274   return {1}({3}){5}
275 )Py";
276 
277 constexpr const char *valueBuilderVariadicTemplate = R"Py(
278 def {0}({2}) -> {4}:
279   return _get_op_result_or_op_results({1}({3}))
280 )Py";
281 
282 static llvm::cl::OptionCategory
283     clOpPythonBindingCat("Options for -gen-python-op-bindings");
284 
285 static llvm::cl::opt<std::string>
286     clDialectName("bind-dialect",
287                   llvm::cl::desc("The dialect to run the generator for"),
288                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
289 
290 static llvm::cl::opt<std::string> clDialectExtensionName(
291     "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
292     llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
293 
294 using AttributeClasses = DenseMap<StringRef, StringRef>;
295 
296 /// Checks whether `str` would shadow a generated variable or attribute
297 /// part of the OpView API.
298 static bool isODSReserved(StringRef str) {
299   static llvm::StringSet<> reserved(
300       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
301        "loc", "verify", "regions", "results", "self", "operation",
302        "DIALECT_NAMESPACE", "OPERATION_NAME"});
303   return str.starts_with("_ods_") || str.ends_with("_ods") ||
304          reserved.contains(str);
305 }
306 
307 /// Modifies the `name` in a way that it becomes suitable for Python bindings
308 /// (does not change the `name` if it already is suitable) and returns the
309 /// modified version.
310 static std::string sanitizeName(StringRef name) {
311   std::string processedStr = name.str();
312   std::replace_if(
313       processedStr.begin(), processedStr.end(),
314       [](char c) { return !llvm::isAlnum(c); }, '_');
315 
316   if (llvm::isDigit(*processedStr.begin()))
317     return "_" + processedStr;
318 
319   if (isPythonReserved(processedStr) || isODSReserved(processedStr))
320     return processedStr + "_";
321   return processedStr;
322 }
323 
324 static std::string attrSizedTraitForKind(const char *kind) {
325   return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
326                  StringRef(kind).take_front().upper(),
327                  StringRef(kind).drop_front());
328 }
329 
330 /// Emits accessors to "elements" of an Op definition. Currently, the supported
331 /// elements are operands and results, indicated by `kind`, which must be either
332 /// `operand` or `result` and is used verbatim in the emitted code.
333 static void emitElementAccessors(
334     const Operator &op, raw_ostream &os, const char *kind,
335     unsigned numVariadicGroups, unsigned numElements,
336     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
337         getElement) {
338   assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
339                             kind) &&
340          "unsupported kind");
341 
342   // Traits indicating how to process variadic elements.
343   std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
344                                       StringRef(kind).take_front().upper(),
345                                       StringRef(kind).drop_front());
346   std::string attrSizedTrait = attrSizedTraitForKind(kind);
347 
348   // If there is only one variable-length element group, its size can be
349   // inferred from the total number of elements. If there are none, the
350   // generation is straightforward.
351   if (numVariadicGroups <= 1) {
352     bool seenVariableLength = false;
353     for (unsigned i = 0; i < numElements; ++i) {
354       const NamedTypeConstraint &element = getElement(op, i);
355       if (element.isVariableLength())
356         seenVariableLength = true;
357       if (element.name.empty())
358         continue;
359       if (element.isVariableLength()) {
360         os << formatv(element.isOptional() ? opOneOptionalTemplate
361                                            : opOneVariadicTemplate,
362                       sanitizeName(element.name), kind, numElements, i);
363       } else if (seenVariableLength) {
364         os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
365                       kind, numElements, i);
366       } else {
367         os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
368       }
369     }
370     return;
371   }
372 
373   // Handle the operations where variadic groups have the same size.
374   if (op.getTrait(sameSizeTrait)) {
375     // Count the number of simple elements
376     unsigned numSimpleLength = 0;
377     for (unsigned i = 0; i < numElements; ++i) {
378       const NamedTypeConstraint &element = getElement(op, i);
379       if (!element.isVariableLength()) {
380         ++numSimpleLength;
381       }
382     }
383 
384     // Generate the accessors
385     int numPrecedingSimple = 0;
386     int numPrecedingVariadic = 0;
387     for (unsigned i = 0; i < numElements; ++i) {
388       const NamedTypeConstraint &element = getElement(op, i);
389       if (!element.name.empty()) {
390         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
391                       kind, numSimpleLength, numVariadicGroups,
392                       numPrecedingSimple, numPrecedingVariadic);
393         os << formatv(element.isVariableLength()
394                           ? opVariadicEqualVariadicTemplate
395                           : opVariadicEqualSimpleTemplate,
396                       kind);
397       }
398       if (element.isVariableLength())
399         ++numPrecedingVariadic;
400       else
401         ++numPrecedingSimple;
402     }
403     return;
404   }
405 
406   // Handle the operations where the size of groups (variadic or not) is
407   // provided as an attribute. For non-variadic elements, make sure to return
408   // an element rather than a singleton container.
409   if (op.getTrait(attrSizedTrait)) {
410     for (unsigned i = 0; i < numElements; ++i) {
411       const NamedTypeConstraint &element = getElement(op, i);
412       if (element.name.empty())
413         continue;
414       std::string trailing;
415       if (!element.isVariableLength())
416         trailing = "[0]";
417       else if (element.isOptional())
418         trailing = std::string(
419             formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
420       os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
421                     i, trailing);
422     }
423     return;
424   }
425 
426   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
427 }
428 
429 /// Free function helpers accessing Operator components.
430 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
431 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
432   return op.getOperand(i);
433 }
434 static int getNumResults(const Operator &op) { return op.getNumResults(); }
435 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
436   return op.getResult(i);
437 }
438 
439 /// Emits accessors to Op operands.
440 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
441   emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
442                        getNumOperands(op), getOperand);
443 }
444 
445 /// Emits accessors Op results.
446 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
447   emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
448                        getNumResults(op), getResult);
449 }
450 
451 /// Emits accessors to Op attributes.
452 static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
453   for (const auto &namedAttr : op.getAttributes()) {
454     // Skip "derived" attributes because they are just C++ functions that we
455     // don't currently expose.
456     if (namedAttr.attr.isDerivedAttr())
457       continue;
458 
459     if (namedAttr.name.empty())
460       continue;
461 
462     std::string sanitizedName = sanitizeName(namedAttr.name);
463 
464     // Unit attributes are handled specially.
465     if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
466       os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
467       os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
468       os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
469       continue;
470     }
471 
472     if (namedAttr.attr.isOptional()) {
473       os << formatv(optionalAttributeGetterTemplate, sanitizedName,
474                     namedAttr.name);
475       os << formatv(optionalAttributeSetterTemplate, sanitizedName,
476                     namedAttr.name);
477       os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
478     } else {
479       os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
480       os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
481       // Non-optional attributes cannot be deleted.
482     }
483   }
484 }
485 
486 /// Template for the default auto-generated builder.
487 ///   {0} is a comma-separated list of builder arguments, including the trailing
488 ///       `loc` and `ip`;
489 ///   {1} is the code populating `operands`, `results` and `attributes`,
490 ///       `successors` fields.
491 constexpr const char *initTemplate = R"Py(
492   def __init__(self, {0}):
493     operands = []
494     results = []
495     attributes = {{}
496     regions = None
497     {1}
498     super().__init__({2})
499 )Py";
500 
501 /// Template for appending a single element to the operand/result list.
502 ///   {0} is the field name.
503 constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
504 constexpr const char *singleResultAppendTemplate = "results.append({0})";
505 
506 /// Template for appending an optional element to the operand/result list.
507 ///   {0} is the field name.
508 constexpr const char *optionalAppendOperandTemplate =
509     "if {0} is not None: operands.append({0})";
510 constexpr const char *optionalAppendAttrSizedOperandsTemplate =
511     "operands.append({0})";
512 constexpr const char *optionalAppendResultTemplate =
513     "if {0} is not None: results.append({0})";
514 
515 /// Template for appending a list of elements to the operand/result list.
516 ///   {0} is the field name.
517 constexpr const char *multiOperandAppendTemplate =
518     "operands.extend(_get_op_results_or_values({0}))";
519 constexpr const char *multiOperandAppendPackTemplate =
520     "operands.append(_get_op_results_or_values({0}))";
521 constexpr const char *multiResultAppendTemplate = "results.extend({0})";
522 
523 /// Template for attribute builder from raw input in the operation builder.
524 ///   {0} is the builder argument name;
525 ///   {1} is the attribute builder from raw;
526 ///   {2} is the attribute builder from raw.
527 /// Use the value the user passed in if either it is already an Attribute or
528 /// there is no method registered to make it an Attribute.
529 constexpr const char *initAttributeWithBuilderTemplate =
530     R"Py(attributes["{1}"] = ({0} if (
531     isinstance({0}, _ods_ir.Attribute) or
532     not _ods_ir.AttrBuilder.contains('{2}')) else
533       _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
534 
535 /// Template for attribute builder from raw input for optional attribute in the
536 /// operation builder.
537 ///   {0} is the builder argument name;
538 ///   {1} is the attribute builder from raw;
539 ///   {2} is the attribute builder from raw.
540 /// Use the value the user passed in if either it is already an Attribute or
541 /// there is no method registered to make it an Attribute.
542 constexpr const char *initOptionalAttributeWithBuilderTemplate =
543     R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
544         isinstance({0}, _ods_ir.Attribute) or
545         not _ods_ir.AttrBuilder.contains('{2}')) else
546           _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
547 
548 constexpr const char *initUnitAttributeTemplate =
549     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
550       _ods_get_default_loc_context(loc)))Py";
551 
552 /// Template to initialize the successors list in the builder if there are any
553 /// successors.
554 ///   {0} is the value to initialize the successors list to.
555 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
556 
557 /// Template to append or extend the list of successors in the builder.
558 ///   {0} is the list method ('append' or 'extend');
559 ///   {1} is the value to add.
560 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
561 
562 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
563 /// result types of the given operation.
564 static bool hasSameArgumentAndResultTypes(const Operator &op) {
565   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
566          op.getNumVariableLengthResults() == 0;
567 }
568 
569 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
570 /// result types of the given operation.
571 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
572   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
573          op.getNumVariableLengthResults() == 0;
574 }
575 
576 /// Returns true if the InferTypeOpInterface can be used to infer result types
577 /// of the given operation.
578 static bool hasInferTypeInterface(const Operator &op) {
579   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
580          op.getNumRegions() == 0;
581 }
582 
583 /// Returns true if there is a trait or interface that can be used to infer
584 /// result types of the given operation.
585 static bool canInferType(const Operator &op) {
586   return hasSameArgumentAndResultTypes(op) ||
587          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
588 }
589 
590 /// Populates `builderArgs` with result names if the builder is expected to
591 /// accept them as arguments.
592 static void
593 populateBuilderArgsResults(const Operator &op,
594                            SmallVectorImpl<std::string> &builderArgs) {
595   if (canInferType(op))
596     return;
597 
598   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
599     std::string name = op.getResultName(i).str();
600     if (name.empty()) {
601       if (op.getNumResults() == 1) {
602         // Special case for one result, make the default name be 'result'
603         // to properly match the built-in result accessor.
604         name = "result";
605       } else {
606         name = formatv("_gen_res_{0}", i);
607       }
608     }
609     name = sanitizeName(name);
610     builderArgs.push_back(name);
611   }
612 }
613 
614 /// Populates `builderArgs` with the Python-compatible names of builder function
615 /// arguments using intermixed attributes and operands in the same order as they
616 /// appear in the `arguments` field of the op definition. Additionally,
617 /// `operandNames` is populated with names of operands in their order of
618 /// appearance.
619 static void populateBuilderArgs(const Operator &op,
620                                 SmallVectorImpl<std::string> &builderArgs,
621                                 SmallVectorImpl<std::string> &operandNames) {
622   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
623     std::string name = op.getArgName(i).str();
624     if (name.empty())
625       name = formatv("_gen_arg_{0}", i);
626     name = sanitizeName(name);
627     builderArgs.push_back(name);
628     if (!isa<NamedAttribute *>(op.getArg(i)))
629       operandNames.push_back(name);
630   }
631 }
632 
633 /// Populates `builderArgs` with the Python-compatible names of builder function
634 /// successor arguments. Additionally, `successorArgNames` is also populated.
635 static void
636 populateBuilderArgsSuccessors(const Operator &op,
637                               SmallVectorImpl<std::string> &builderArgs,
638                               SmallVectorImpl<std::string> &successorArgNames) {
639 
640   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
641     NamedSuccessor successor = op.getSuccessor(i);
642     std::string name = std::string(successor.name);
643     if (name.empty())
644       name = formatv("_gen_successor_{0}", i);
645     name = sanitizeName(name);
646     builderArgs.push_back(name);
647     successorArgNames.push_back(name);
648   }
649 }
650 
651 /// Populates `builderLines` with additional lines that are required in the
652 /// builder to set up operation attributes. `argNames` is expected to contain
653 /// the names of builder arguments that correspond to op arguments, i.e. to the
654 /// operands and attributes in the same order as they appear in the `arguments`
655 /// field.
656 static void
657 populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
658                          SmallVectorImpl<std::string> &builderLines) {
659   builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
660   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
661     Argument arg = op.getArg(i);
662     auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
663     if (!attribute)
664       continue;
665 
666     // Unit attributes are handled specially.
667     if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
668       builderLines.push_back(
669           formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
670       continue;
671     }
672 
673     builderLines.push_back(formatv(
674         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
675             ? initOptionalAttributeWithBuilderTemplate
676             : initAttributeWithBuilderTemplate,
677         argNames[i], attribute->name, attribute->attr.getAttrDefName()));
678   }
679 }
680 
681 /// Populates `builderLines` with additional lines that are required in the
682 /// builder to set up successors. successorArgNames is expected to correspond
683 /// to the Python argument name for each successor on the op.
684 static void
685 populateBuilderLinesSuccessors(const Operator &op,
686                                ArrayRef<std::string> successorArgNames,
687                                SmallVectorImpl<std::string> &builderLines) {
688   if (successorArgNames.empty()) {
689     builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
690     return;
691   }
692 
693   builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
694   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
695     auto &argName = successorArgNames[i];
696     const NamedSuccessor &successor = op.getSuccessor(i);
697     builderLines.push_back(formatv(addSuccessorTemplate,
698                                    successor.isVariadic() ? "extend" : "append",
699                                    argName));
700   }
701 }
702 
703 /// Populates `builderLines` with additional lines that are required in the
704 /// builder to set up op operands.
705 static void
706 populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
707                             SmallVectorImpl<std::string> &builderLines) {
708   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
709 
710   // For each element, find or generate a name.
711   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
712     const NamedTypeConstraint &element = op.getOperand(i);
713     std::string name = names[i];
714 
715     // Choose the formatting string based on the element kind.
716     StringRef formatString;
717     if (!element.isVariableLength()) {
718       formatString = singleOperandAppendTemplate;
719     } else if (element.isOptional()) {
720       if (sizedSegments) {
721         formatString = optionalAppendAttrSizedOperandsTemplate;
722       } else {
723         formatString = optionalAppendOperandTemplate;
724       }
725     } else {
726       assert(element.isVariadic() && "unhandled element group type");
727       // If emitting with sizedSegments, then we add the actual list-typed
728       // element. Otherwise, we extend the actual operands.
729       if (sizedSegments) {
730         formatString = multiOperandAppendPackTemplate;
731       } else {
732         formatString = multiOperandAppendTemplate;
733       }
734     }
735 
736     builderLines.push_back(formatv(formatString.data(), name));
737   }
738 }
739 
740 /// Python code template for deriving the operation result types from its
741 /// attribute:
742 ///   - {0} is the name of the attribute from which to derive the types.
743 constexpr const char *deriveTypeFromAttrTemplate =
744     R"Py(_ods_result_type_source_attr = attributes["{0}"]
745 _ods_derived_result_type = (
746     _ods_ir.TypeAttr(_ods_result_type_source_attr).value
747     if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
748     _ods_result_type_source_attr.type))Py";
749 
750 /// Python code template appending {0} type {1} times to the results list.
751 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
752 
753 /// Appends the given multiline string as individual strings into
754 /// `builderLines`.
755 static void appendLineByLine(StringRef string,
756                              SmallVectorImpl<std::string> &builderLines) {
757 
758   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
759   do {
760     split = split.second.split('\n');
761     builderLines.push_back(split.first.str());
762   } while (!split.second.empty());
763 }
764 
765 /// Populates `builderLines` with additional lines that are required in the
766 /// builder to set up op results.
767 static void
768 populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
769                            SmallVectorImpl<std::string> &builderLines) {
770   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
771 
772   if (hasSameArgumentAndResultTypes(op)) {
773     builderLines.push_back(formatv(appendSameResultsTemplate,
774                                    "operands[0].type", op.getNumResults()));
775     return;
776   }
777 
778   if (hasFirstAttrDerivedResultTypes(op)) {
779     const NamedAttribute &firstAttr = op.getAttribute(0);
780     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
781                                       "from which the type is derived");
782     appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
783                      builderLines);
784     builderLines.push_back(formatv(appendSameResultsTemplate,
785                                    "_ods_derived_result_type",
786                                    op.getNumResults()));
787     return;
788   }
789 
790   if (hasInferTypeInterface(op))
791     return;
792 
793   // For each element, find or generate a name.
794   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
795     const NamedTypeConstraint &element = op.getResult(i);
796     std::string name = names[i];
797 
798     // Choose the formatting string based on the element kind.
799     StringRef formatString;
800     if (!element.isVariableLength()) {
801       formatString = singleResultAppendTemplate;
802     } else if (element.isOptional()) {
803       formatString = optionalAppendResultTemplate;
804     } else {
805       assert(element.isVariadic() && "unhandled element group type");
806       // If emitting with sizedSegments, then we add the actual list-typed
807       // element. Otherwise, we extend the actual operands.
808       if (sizedSegments) {
809         formatString = singleResultAppendTemplate;
810       } else {
811         formatString = multiResultAppendTemplate;
812       }
813     }
814 
815     builderLines.push_back(formatv(formatString.data(), name));
816   }
817 }
818 
819 /// If the operation has variadic regions, adds a builder argument to specify
820 /// the number of those regions and builder lines to forward it to the generic
821 /// constructor.
822 static void populateBuilderRegions(const Operator &op,
823                                    SmallVectorImpl<std::string> &builderArgs,
824                                    SmallVectorImpl<std::string> &builderLines) {
825   if (op.hasNoVariadicRegions())
826     return;
827 
828   // This is currently enforced when Operator is constructed.
829   assert(op.getNumVariadicRegions() == 1 &&
830          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
831          "expected the last region to be varidic");
832 
833   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
834   std::string name =
835       ("num_" + region.name.take_front().lower() + region.name.drop_front())
836           .str();
837   builderArgs.push_back(name);
838   builderLines.push_back(
839       formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
840 }
841 
842 /// Emits a default builder constructing an operation from the list of its
843 /// result types, followed by a list of its operands. Returns vector
844 /// of fully built functionArgs for downstream users (to save having to
845 /// rebuild anew).
846 static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
847                                                      raw_ostream &os) {
848   SmallVector<std::string> builderArgs;
849   SmallVector<std::string> builderLines;
850   SmallVector<std::string> operandArgNames;
851   SmallVector<std::string> successorArgNames;
852   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
853                       op.getNumNativeAttributes() + op.getNumSuccessors());
854   populateBuilderArgsResults(op, builderArgs);
855   size_t numResultArgs = builderArgs.size();
856   populateBuilderArgs(op, builderArgs, operandArgNames);
857   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
858   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
859 
860   populateBuilderLinesOperand(op, operandArgNames, builderLines);
861   populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
862                            builderLines);
863   populateBuilderLinesResult(
864       op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
865   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
866   populateBuilderRegions(op, builderArgs, builderLines);
867 
868   // Layout of builderArgs vector elements:
869   // [ result_args  operand_attr_args successor_args regions ]
870 
871   // Determine whether the argument corresponding to a given index into the
872   // builderArgs vector is a python keyword argument or not.
873   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
874     // All result, successor, and region arguments are positional arguments.
875     if ((builderArgIndex < numResultArgs) ||
876         (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
877       return false;
878     // Keyword arguments:
879     // - optional named attributes (including unit attributes)
880     // - default-valued named attributes
881     // - optional operands
882     Argument a = op.getArg(builderArgIndex - numResultArgs);
883     if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
884       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
885     if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
886       return ntype->isOptional();
887     return false;
888   };
889 
890   // StringRefs in functionArgs refer to strings allocated by builderArgs.
891   SmallVector<StringRef> functionArgs;
892 
893   // Add positional arguments.
894   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
895     if (!isKeywordArgFn(i))
896       functionArgs.push_back(builderArgs[i]);
897   }
898 
899   // Add a bare '*' to indicate that all following arguments must be keyword
900   // arguments.
901   functionArgs.push_back("*");
902 
903   // Add a default 'None' value to each keyword arg string, and then add to the
904   // function args list.
905   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
906     if (isKeywordArgFn(i)) {
907       builderArgs[i].append("=None");
908       functionArgs.push_back(builderArgs[i]);
909     }
910   }
911   functionArgs.push_back("loc=None");
912   functionArgs.push_back("ip=None");
913 
914   SmallVector<std::string> initArgs;
915   initArgs.push_back("self.OPERATION_NAME");
916   initArgs.push_back("self._ODS_REGIONS");
917   initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
918   initArgs.push_back("self._ODS_RESULT_SEGMENTS");
919   initArgs.push_back("attributes=attributes");
920   if (!hasInferTypeInterface(op))
921     initArgs.push_back("results=results");
922   initArgs.push_back("operands=operands");
923   initArgs.push_back("successors=_ods_successors");
924   initArgs.push_back("regions=regions");
925   initArgs.push_back("loc=loc");
926   initArgs.push_back("ip=ip");
927 
928   os << formatv(initTemplate, llvm::join(functionArgs, ", "),
929                 llvm::join(builderLines, "\n    "), llvm::join(initArgs, ", "));
930   return llvm::to_vector<8>(
931       llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
932 }
933 
934 static void emitSegmentSpec(
935     const Operator &op, const char *kind,
936     llvm::function_ref<int(const Operator &)> getNumElements,
937     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
938         getElement,
939     raw_ostream &os) {
940   std::string segmentSpec("[");
941   for (int i = 0, e = getNumElements(op); i < e; ++i) {
942     const NamedTypeConstraint &element = getElement(op, i);
943     if (element.isOptional()) {
944       segmentSpec.append("0,");
945     } else if (element.isVariadic()) {
946       segmentSpec.append("-1,");
947     } else {
948       segmentSpec.append("1,");
949     }
950   }
951   segmentSpec.append("]");
952 
953   os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
954 }
955 
956 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
957   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
958   // Note that the base OpView class defines this as (0, True).
959   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
960   os << formatv(opClassRegionSpecTemplate, minRegionCount,
961                 op.hasNoVariadicRegions() ? "True" : "False");
962 }
963 
964 /// Emits named accessors to regions.
965 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
966   for (const auto &en : llvm::enumerate(op.getRegions())) {
967     const NamedRegion &region = en.value();
968     if (region.name.empty())
969       continue;
970 
971     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
972            "expected only the last region to be variadic");
973     os << formatv(regionAccessorTemplate, sanitizeName(region.name),
974                   std::to_string(en.index()) +
975                       (region.isVariadic() ? ":" : ""));
976   }
977 }
978 
979 /// Emits builder that extracts results from op
980 static void emitValueBuilder(const Operator &op,
981                              SmallVector<std::string> functionArgs,
982                              raw_ostream &os) {
983   // Params with (possibly) default args.
984   auto valueBuilderParams =
985       llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
986         SmallVector<StringRef> argMaybeDefault =
987             llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
988         auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
989         if (argMaybeDefault.size() == 2)
990           return arg + "=" + argMaybeDefault[1].str();
991         return arg;
992       });
993   // Actual args passed to op builder (e.g., opParam=op_param).
994   auto opBuilderArgs = llvm::map_range(
995       llvm::make_filter_range(functionArgs,
996                               [](const std::string &s) { return s != "*"; }),
997       [](const std::string &arg) {
998         auto lhs = *llvm::split(arg, "=").begin();
999         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
1000       });
1001   std::string nameWithoutDialect = sanitizeName(
1002       op.getOperationName().substr(op.getOperationName().find('.') + 1));
1003   std::string params = llvm::join(valueBuilderParams, ", ");
1004   std::string args = llvm::join(opBuilderArgs, ", ");
1005   const char *type =
1006       (op.getNumResults() > 1
1007            ? "_Sequence[_ods_ir.Value]"
1008            : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
1009   if (op.getNumVariableLengthResults() > 0) {
1010     os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
1011                   op.getCppClassName(), params, args, type);
1012   } else {
1013     const char *results;
1014     if (op.getNumResults() == 0) {
1015       results = "";
1016     } else if (op.getNumResults() == 1) {
1017       results = ".result";
1018     } else {
1019       results = ".results";
1020     }
1021     os << formatv(valueBuilderTemplate, nameWithoutDialect,
1022                   op.getCppClassName(), params, args, type, results);
1023   }
1024 }
1025 
1026 /// Emits bindings for a specific Op to the given output stream.
1027 static void emitOpBindings(const Operator &op, raw_ostream &os) {
1028   os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
1029 
1030   // Sized segments.
1031   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1032     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1033   }
1034   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1035     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1036   }
1037 
1038   emitRegionAttributes(op, os);
1039   SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
1040   emitOperandAccessors(op, os);
1041   emitAttributeAccessors(op, os);
1042   emitResultAccessors(op, os);
1043   emitRegionAccessors(op, os);
1044   emitValueBuilder(op, functionArgs, os);
1045 }
1046 
1047 /// Emits bindings for the dialect specified in the command line, including file
1048 /// headers and utilities. Returns `false` on success to comply with Tablegen
1049 /// registration requirements.
1050 static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
1051   if (clDialectName.empty())
1052     llvm::PrintFatalError("dialect name not provided");
1053 
1054   os << fileHeader;
1055   if (!clDialectExtensionName.empty())
1056     os << formatv(dialectExtensionTemplate, clDialectName.getValue());
1057   else
1058     os << formatv(dialectClassTemplate, clDialectName.getValue());
1059 
1060   for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
1061     Operator op(rec);
1062     if (op.getDialectName() == clDialectName.getValue())
1063       emitOpBindings(op, os);
1064   }
1065   return false;
1066 }
1067 
1068 static GenRegistration
1069     genPythonBindings("gen-python-op-bindings",
1070                       "Generate Python bindings for MLIR Ops", &emitAllOps);
1071