xref: /llvm-project/mlir/lib/Bindings/Python/IRTypes.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // clang-format off
10 #include "IRModule.h"
11 #include "mlir/Bindings/Python/IRTypes.h"
12 // clang-format on
13 
14 #include <optional>
15 
16 #include "IRModule.h"
17 #include "NanobindUtils.h"
18 #include "mlir-c/BuiltinAttributes.h"
19 #include "mlir-c/BuiltinTypes.h"
20 #include "mlir-c/Support.h"
21 
22 namespace nb = nanobind;
23 using namespace mlir;
24 using namespace mlir::python;
25 
26 using llvm::SmallVector;
27 using llvm::Twine;
28 
29 namespace {
30 
31 /// Checks whether the given type is an integer or float type.
32 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35 }
36 
37 class PyIntegerType : public PyConcreteType<PyIntegerType> {
38 public:
39   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
41       mlirIntegerTypeGetTypeID;
42   static constexpr const char *pyClassName = "IntegerType";
43   using PyConcreteType::PyConcreteType;
44 
45   static void bindDerived(ClassTy &c) {
46     c.def_static(
47         "get_signless",
48         [](unsigned width, DefaultingPyMlirContext context) {
49           MlirType t = mlirIntegerTypeGet(context->get(), width);
50           return PyIntegerType(context->getRef(), t);
51         },
52         nb::arg("width"), nb::arg("context").none() = nb::none(),
53         "Create a signless integer type");
54     c.def_static(
55         "get_signed",
56         [](unsigned width, DefaultingPyMlirContext context) {
57           MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58           return PyIntegerType(context->getRef(), t);
59         },
60         nb::arg("width"), nb::arg("context").none() = nb::none(),
61         "Create a signed integer type");
62     c.def_static(
63         "get_unsigned",
64         [](unsigned width, DefaultingPyMlirContext context) {
65           MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66           return PyIntegerType(context->getRef(), t);
67         },
68         nb::arg("width"), nb::arg("context").none() = nb::none(),
69         "Create an unsigned integer type");
70     c.def_prop_ro(
71         "width",
72         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73         "Returns the width of the integer type");
74     c.def_prop_ro(
75         "is_signless",
76         [](PyIntegerType &self) -> bool {
77           return mlirIntegerTypeIsSignless(self);
78         },
79         "Returns whether this is a signless integer");
80     c.def_prop_ro(
81         "is_signed",
82         [](PyIntegerType &self) -> bool {
83           return mlirIntegerTypeIsSigned(self);
84         },
85         "Returns whether this is a signed integer");
86     c.def_prop_ro(
87         "is_unsigned",
88         [](PyIntegerType &self) -> bool {
89           return mlirIntegerTypeIsUnsigned(self);
90         },
91         "Returns whether this is an unsigned integer");
92   }
93 };
94 
95 /// Index Type subclass - IndexType.
96 class PyIndexType : public PyConcreteType<PyIndexType> {
97 public:
98   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
100       mlirIndexTypeGetTypeID;
101   static constexpr const char *pyClassName = "IndexType";
102   using PyConcreteType::PyConcreteType;
103 
104   static void bindDerived(ClassTy &c) {
105     c.def_static(
106         "get",
107         [](DefaultingPyMlirContext context) {
108           MlirType t = mlirIndexTypeGet(context->get());
109           return PyIndexType(context->getRef(), t);
110         },
111         nb::arg("context").none() = nb::none(), "Create a index type.");
112   }
113 };
114 
115 class PyFloatType : public PyConcreteType<PyFloatType> {
116 public:
117   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118   static constexpr const char *pyClassName = "FloatType";
119   using PyConcreteType::PyConcreteType;
120 
121   static void bindDerived(ClassTy &c) {
122     c.def_prop_ro(
123         "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124         "Returns the width of the floating-point type");
125   }
126 };
127 
128 /// Floating Point Type subclass - Float4E2M1FNType.
129 class PyFloat4E2M1FNType
130     : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131 public:
132   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
134       mlirFloat4E2M1FNTypeGetTypeID;
135   static constexpr const char *pyClassName = "Float4E2M1FNType";
136   using PyConcreteType::PyConcreteType;
137 
138   static void bindDerived(ClassTy &c) {
139     c.def_static(
140         "get",
141         [](DefaultingPyMlirContext context) {
142           MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143           return PyFloat4E2M1FNType(context->getRef(), t);
144         },
145         nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
146   }
147 };
148 
149 /// Floating Point Type subclass - Float6E2M3FNType.
150 class PyFloat6E2M3FNType
151     : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152 public:
153   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
155       mlirFloat6E2M3FNTypeGetTypeID;
156   static constexpr const char *pyClassName = "Float6E2M3FNType";
157   using PyConcreteType::PyConcreteType;
158 
159   static void bindDerived(ClassTy &c) {
160     c.def_static(
161         "get",
162         [](DefaultingPyMlirContext context) {
163           MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164           return PyFloat6E2M3FNType(context->getRef(), t);
165         },
166         nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
167   }
168 };
169 
170 /// Floating Point Type subclass - Float6E3M2FNType.
171 class PyFloat6E3M2FNType
172     : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173 public:
174   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
176       mlirFloat6E3M2FNTypeGetTypeID;
177   static constexpr const char *pyClassName = "Float6E3M2FNType";
178   using PyConcreteType::PyConcreteType;
179 
180   static void bindDerived(ClassTy &c) {
181     c.def_static(
182         "get",
183         [](DefaultingPyMlirContext context) {
184           MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185           return PyFloat6E3M2FNType(context->getRef(), t);
186         },
187         nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
188   }
189 };
190 
191 /// Floating Point Type subclass - Float8E4M3FNType.
192 class PyFloat8E4M3FNType
193     : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194 public:
195   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
197       mlirFloat8E4M3FNTypeGetTypeID;
198   static constexpr const char *pyClassName = "Float8E4M3FNType";
199   using PyConcreteType::PyConcreteType;
200 
201   static void bindDerived(ClassTy &c) {
202     c.def_static(
203         "get",
204         [](DefaultingPyMlirContext context) {
205           MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206           return PyFloat8E4M3FNType(context->getRef(), t);
207         },
208         nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
209   }
210 };
211 
212 /// Floating Point Type subclass - Float8E5M2Type.
213 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214 public:
215   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
217       mlirFloat8E5M2TypeGetTypeID;
218   static constexpr const char *pyClassName = "Float8E5M2Type";
219   using PyConcreteType::PyConcreteType;
220 
221   static void bindDerived(ClassTy &c) {
222     c.def_static(
223         "get",
224         [](DefaultingPyMlirContext context) {
225           MlirType t = mlirFloat8E5M2TypeGet(context->get());
226           return PyFloat8E5M2Type(context->getRef(), t);
227         },
228         nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
229   }
230 };
231 
232 /// Floating Point Type subclass - Float8E4M3Type.
233 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234 public:
235   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
237       mlirFloat8E4M3TypeGetTypeID;
238   static constexpr const char *pyClassName = "Float8E4M3Type";
239   using PyConcreteType::PyConcreteType;
240 
241   static void bindDerived(ClassTy &c) {
242     c.def_static(
243         "get",
244         [](DefaultingPyMlirContext context) {
245           MlirType t = mlirFloat8E4M3TypeGet(context->get());
246           return PyFloat8E4M3Type(context->getRef(), t);
247         },
248         nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
249   }
250 };
251 
252 /// Floating Point Type subclass - Float8E4M3FNUZ.
253 class PyFloat8E4M3FNUZType
254     : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255 public:
256   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
258       mlirFloat8E4M3FNUZTypeGetTypeID;
259   static constexpr const char *pyClassName = "Float8E4M3FNUZType";
260   using PyConcreteType::PyConcreteType;
261 
262   static void bindDerived(ClassTy &c) {
263     c.def_static(
264         "get",
265         [](DefaultingPyMlirContext context) {
266           MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267           return PyFloat8E4M3FNUZType(context->getRef(), t);
268         },
269         nb::arg("context").none() = nb::none(),
270         "Create a float8_e4m3fnuz type.");
271   }
272 };
273 
274 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
275 class PyFloat8E4M3B11FNUZType
276     : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
277 public:
278   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
279   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280       mlirFloat8E4M3B11FNUZTypeGetTypeID;
281   static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282   using PyConcreteType::PyConcreteType;
283 
284   static void bindDerived(ClassTy &c) {
285     c.def_static(
286         "get",
287         [](DefaultingPyMlirContext context) {
288           MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
289           return PyFloat8E4M3B11FNUZType(context->getRef(), t);
290         },
291         nb::arg("context").none() = nb::none(),
292         "Create a float8_e4m3b11fnuz type.");
293   }
294 };
295 
296 /// Floating Point Type subclass - Float8E5M2FNUZ.
297 class PyFloat8E5M2FNUZType
298     : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
299 public:
300   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
301   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
302       mlirFloat8E5M2FNUZTypeGetTypeID;
303   static constexpr const char *pyClassName = "Float8E5M2FNUZType";
304   using PyConcreteType::PyConcreteType;
305 
306   static void bindDerived(ClassTy &c) {
307     c.def_static(
308         "get",
309         [](DefaultingPyMlirContext context) {
310           MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
311           return PyFloat8E5M2FNUZType(context->getRef(), t);
312         },
313         nb::arg("context").none() = nb::none(),
314         "Create a float8_e5m2fnuz type.");
315   }
316 };
317 
318 /// Floating Point Type subclass - Float8E3M4Type.
319 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
320 public:
321   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
322   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
323       mlirFloat8E3M4TypeGetTypeID;
324   static constexpr const char *pyClassName = "Float8E3M4Type";
325   using PyConcreteType::PyConcreteType;
326 
327   static void bindDerived(ClassTy &c) {
328     c.def_static(
329         "get",
330         [](DefaultingPyMlirContext context) {
331           MlirType t = mlirFloat8E3M4TypeGet(context->get());
332           return PyFloat8E3M4Type(context->getRef(), t);
333         },
334         nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
335   }
336 };
337 
338 /// Floating Point Type subclass - Float8E8M0FNUType.
339 class PyFloat8E8M0FNUType
340     : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
341 public:
342   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
343   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
344       mlirFloat8E8M0FNUTypeGetTypeID;
345   static constexpr const char *pyClassName = "Float8E8M0FNUType";
346   using PyConcreteType::PyConcreteType;
347 
348   static void bindDerived(ClassTy &c) {
349     c.def_static(
350         "get",
351         [](DefaultingPyMlirContext context) {
352           MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
353           return PyFloat8E8M0FNUType(context->getRef(), t);
354         },
355         nb::arg("context").none() = nb::none(),
356         "Create a float8_e8m0fnu type.");
357   }
358 };
359 
360 /// Floating Point Type subclass - BF16Type.
361 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
362 public:
363   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
364   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
365       mlirBFloat16TypeGetTypeID;
366   static constexpr const char *pyClassName = "BF16Type";
367   using PyConcreteType::PyConcreteType;
368 
369   static void bindDerived(ClassTy &c) {
370     c.def_static(
371         "get",
372         [](DefaultingPyMlirContext context) {
373           MlirType t = mlirBF16TypeGet(context->get());
374           return PyBF16Type(context->getRef(), t);
375         },
376         nb::arg("context").none() = nb::none(), "Create a bf16 type.");
377   }
378 };
379 
380 /// Floating Point Type subclass - F16Type.
381 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
382 public:
383   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
384   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
385       mlirFloat16TypeGetTypeID;
386   static constexpr const char *pyClassName = "F16Type";
387   using PyConcreteType::PyConcreteType;
388 
389   static void bindDerived(ClassTy &c) {
390     c.def_static(
391         "get",
392         [](DefaultingPyMlirContext context) {
393           MlirType t = mlirF16TypeGet(context->get());
394           return PyF16Type(context->getRef(), t);
395         },
396         nb::arg("context").none() = nb::none(), "Create a f16 type.");
397   }
398 };
399 
400 /// Floating Point Type subclass - TF32Type.
401 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
402 public:
403   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
404   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
405       mlirFloatTF32TypeGetTypeID;
406   static constexpr const char *pyClassName = "FloatTF32Type";
407   using PyConcreteType::PyConcreteType;
408 
409   static void bindDerived(ClassTy &c) {
410     c.def_static(
411         "get",
412         [](DefaultingPyMlirContext context) {
413           MlirType t = mlirTF32TypeGet(context->get());
414           return PyTF32Type(context->getRef(), t);
415         },
416         nb::arg("context").none() = nb::none(), "Create a tf32 type.");
417   }
418 };
419 
420 /// Floating Point Type subclass - F32Type.
421 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
422 public:
423   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
424   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
425       mlirFloat32TypeGetTypeID;
426   static constexpr const char *pyClassName = "F32Type";
427   using PyConcreteType::PyConcreteType;
428 
429   static void bindDerived(ClassTy &c) {
430     c.def_static(
431         "get",
432         [](DefaultingPyMlirContext context) {
433           MlirType t = mlirF32TypeGet(context->get());
434           return PyF32Type(context->getRef(), t);
435         },
436         nb::arg("context").none() = nb::none(), "Create a f32 type.");
437   }
438 };
439 
440 /// Floating Point Type subclass - F64Type.
441 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
442 public:
443   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
444   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
445       mlirFloat64TypeGetTypeID;
446   static constexpr const char *pyClassName = "F64Type";
447   using PyConcreteType::PyConcreteType;
448 
449   static void bindDerived(ClassTy &c) {
450     c.def_static(
451         "get",
452         [](DefaultingPyMlirContext context) {
453           MlirType t = mlirF64TypeGet(context->get());
454           return PyF64Type(context->getRef(), t);
455         },
456         nb::arg("context").none() = nb::none(), "Create a f64 type.");
457   }
458 };
459 
460 /// None Type subclass - NoneType.
461 class PyNoneType : public PyConcreteType<PyNoneType> {
462 public:
463   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
464   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
465       mlirNoneTypeGetTypeID;
466   static constexpr const char *pyClassName = "NoneType";
467   using PyConcreteType::PyConcreteType;
468 
469   static void bindDerived(ClassTy &c) {
470     c.def_static(
471         "get",
472         [](DefaultingPyMlirContext context) {
473           MlirType t = mlirNoneTypeGet(context->get());
474           return PyNoneType(context->getRef(), t);
475         },
476         nb::arg("context").none() = nb::none(), "Create a none type.");
477   }
478 };
479 
480 /// Complex Type subclass - ComplexType.
481 class PyComplexType : public PyConcreteType<PyComplexType> {
482 public:
483   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
484   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
485       mlirComplexTypeGetTypeID;
486   static constexpr const char *pyClassName = "ComplexType";
487   using PyConcreteType::PyConcreteType;
488 
489   static void bindDerived(ClassTy &c) {
490     c.def_static(
491         "get",
492         [](PyType &elementType) {
493           // The element must be a floating point or integer scalar type.
494           if (mlirTypeIsAIntegerOrFloat(elementType)) {
495             MlirType t = mlirComplexTypeGet(elementType);
496             return PyComplexType(elementType.getContext(), t);
497           }
498           throw nb::value_error(
499               (Twine("invalid '") +
500                nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
501                "' and expected floating point or integer type.")
502                   .str()
503                   .c_str());
504         },
505         "Create a complex type");
506     c.def_prop_ro(
507         "element_type",
508         [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
509         "Returns element type.");
510   }
511 };
512 
513 } // namespace
514 
515 // Shaped Type Interface - ShapedType
516 void mlir::PyShapedType::bindDerived(ClassTy &c) {
517   c.def_prop_ro(
518       "element_type",
519       [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
520       "Returns the element type of the shaped type.");
521   c.def_prop_ro(
522       "has_rank",
523       [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
524       "Returns whether the given shaped type is ranked.");
525   c.def_prop_ro(
526       "rank",
527       [](PyShapedType &self) {
528         self.requireHasRank();
529         return mlirShapedTypeGetRank(self);
530       },
531       "Returns the rank of the given ranked shaped type.");
532   c.def_prop_ro(
533       "has_static_shape",
534       [](PyShapedType &self) -> bool {
535         return mlirShapedTypeHasStaticShape(self);
536       },
537       "Returns whether the given shaped type has a static shape.");
538   c.def(
539       "is_dynamic_dim",
540       [](PyShapedType &self, intptr_t dim) -> bool {
541         self.requireHasRank();
542         return mlirShapedTypeIsDynamicDim(self, dim);
543       },
544       nb::arg("dim"),
545       "Returns whether the dim-th dimension of the given shaped type is "
546       "dynamic.");
547   c.def(
548       "get_dim_size",
549       [](PyShapedType &self, intptr_t dim) {
550         self.requireHasRank();
551         return mlirShapedTypeGetDimSize(self, dim);
552       },
553       nb::arg("dim"),
554       "Returns the dim-th dimension of the given ranked shaped type.");
555   c.def_static(
556       "is_dynamic_size",
557       [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
558       nb::arg("dim_size"),
559       "Returns whether the given dimension size indicates a dynamic "
560       "dimension.");
561   c.def(
562       "is_dynamic_stride_or_offset",
563       [](PyShapedType &self, int64_t val) -> bool {
564         self.requireHasRank();
565         return mlirShapedTypeIsDynamicStrideOrOffset(val);
566       },
567       nb::arg("dim_size"),
568       "Returns whether the given value is used as a placeholder for dynamic "
569       "strides and offsets in shaped types.");
570   c.def_prop_ro(
571       "shape",
572       [](PyShapedType &self) {
573         self.requireHasRank();
574 
575         std::vector<int64_t> shape;
576         int64_t rank = mlirShapedTypeGetRank(self);
577         shape.reserve(rank);
578         for (int64_t i = 0; i < rank; ++i)
579           shape.push_back(mlirShapedTypeGetDimSize(self, i));
580         return shape;
581       },
582       "Returns the shape of the ranked shaped type as a list of integers.");
583   c.def_static(
584       "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
585       "Returns the value used to indicate dynamic dimensions in shaped "
586       "types.");
587   c.def_static(
588       "get_dynamic_stride_or_offset",
589       []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
590       "Returns the value used to indicate dynamic strides or offsets in "
591       "shaped types.");
592 }
593 
594 void mlir::PyShapedType::requireHasRank() {
595   if (!mlirShapedTypeHasRank(*this)) {
596     throw nb::value_error(
597         "calling this method requires that the type has a rank.");
598   }
599 }
600 
// Out-of-line definition of the static `isaFunction` member of PyShapedType,
// set to the C API predicate mlirTypeIsAShaped.
const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
    mlirTypeIsAShaped;
603 
604 namespace {
605 
606 /// Vector Type subclass - VectorType.
607 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
608 public:
609   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
610   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
611       mlirVectorTypeGetTypeID;
612   static constexpr const char *pyClassName = "VectorType";
613   using PyConcreteType::PyConcreteType;
614 
615   static void bindDerived(ClassTy &c) {
616     c.def_static("get", &PyVectorType::get, nb::arg("shape"),
617                  nb::arg("element_type"), nb::kw_only(),
618                  nb::arg("scalable").none() = nb::none(),
619                  nb::arg("scalable_dims").none() = nb::none(),
620                  nb::arg("loc").none() = nb::none(), "Create a vector type")
621         .def_prop_ro(
622             "scalable",
623             [](MlirType self) { return mlirVectorTypeIsScalable(self); })
624         .def_prop_ro("scalable_dims", [](MlirType self) {
625           std::vector<bool> scalableDims;
626           size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
627           scalableDims.reserve(rank);
628           for (size_t i = 0; i < rank; ++i)
629             scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
630           return scalableDims;
631         });
632   }
633 
634 private:
635   static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
636                           std::optional<nb::list> scalable,
637                           std::optional<std::vector<int64_t>> scalableDims,
638                           DefaultingPyLocation loc) {
639     if (scalable && scalableDims) {
640       throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
641                             "are mutually exclusive.");
642     }
643 
644     PyMlirContext::ErrorCapture errors(loc->getContext());
645     MlirType type;
646     if (scalable) {
647       if (scalable->size() != shape.size())
648         throw nb::value_error("Expected len(scalable) == len(shape).");
649 
650       SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
651           *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
652       type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
653                                               scalableDimFlags.data(),
654                                               elementType);
655     } else if (scalableDims) {
656       SmallVector<bool> scalableDimFlags(shape.size(), false);
657       for (int64_t dim : *scalableDims) {
658         if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
659           throw nb::value_error("Scalable dimension index out of bounds.");
660         scalableDimFlags[dim] = true;
661       }
662       type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
663                                               scalableDimFlags.data(),
664                                               elementType);
665     } else {
666       type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
667                                       elementType);
668     }
669     if (mlirTypeIsNull(type))
670       throw MLIRError("Invalid type", errors.take());
671     return PyVectorType(elementType.getContext(), type);
672   }
673 };
674 
675 /// Ranked Tensor Type subclass - RankedTensorType.
676 class PyRankedTensorType
677     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
678 public:
679   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
680   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
681       mlirRankedTensorTypeGetTypeID;
682   static constexpr const char *pyClassName = "RankedTensorType";
683   using PyConcreteType::PyConcreteType;
684 
685   static void bindDerived(ClassTy &c) {
686     c.def_static(
687         "get",
688         [](std::vector<int64_t> shape, PyType &elementType,
689            std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
690           PyMlirContext::ErrorCapture errors(loc->getContext());
691           MlirType t = mlirRankedTensorTypeGetChecked(
692               loc, shape.size(), shape.data(), elementType,
693               encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
694           if (mlirTypeIsNull(t))
695             throw MLIRError("Invalid type", errors.take());
696           return PyRankedTensorType(elementType.getContext(), t);
697         },
698         nb::arg("shape"), nb::arg("element_type"),
699         nb::arg("encoding").none() = nb::none(),
700         nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
701     c.def_prop_ro("encoding",
702                   [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
703                     MlirAttribute encoding =
704                         mlirRankedTensorTypeGetEncoding(self.get());
705                     if (mlirAttributeIsNull(encoding))
706                       return std::nullopt;
707                     return encoding;
708                   });
709   }
710 };
711 
712 /// Unranked Tensor Type subclass - UnrankedTensorType.
713 class PyUnrankedTensorType
714     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
715 public:
716   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
717   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
718       mlirUnrankedTensorTypeGetTypeID;
719   static constexpr const char *pyClassName = "UnrankedTensorType";
720   using PyConcreteType::PyConcreteType;
721 
722   static void bindDerived(ClassTy &c) {
723     c.def_static(
724         "get",
725         [](PyType &elementType, DefaultingPyLocation loc) {
726           PyMlirContext::ErrorCapture errors(loc->getContext());
727           MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
728           if (mlirTypeIsNull(t))
729             throw MLIRError("Invalid type", errors.take());
730           return PyUnrankedTensorType(elementType.getContext(), t);
731         },
732         nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
733         "Create a unranked tensor type");
734   }
735 };
736 
737 /// Ranked MemRef Type subclass - MemRefType.
738 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
739 public:
740   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
741   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
742       mlirMemRefTypeGetTypeID;
743   static constexpr const char *pyClassName = "MemRefType";
744   using PyConcreteType::PyConcreteType;
745 
746   static void bindDerived(ClassTy &c) {
747     c.def_static(
748          "get",
749          [](std::vector<int64_t> shape, PyType &elementType,
750             PyAttribute *layout, PyAttribute *memorySpace,
751             DefaultingPyLocation loc) {
752            PyMlirContext::ErrorCapture errors(loc->getContext());
753            MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
754            MlirAttribute memSpaceAttr =
755                memorySpace ? *memorySpace : mlirAttributeGetNull();
756            MlirType t =
757                mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
758                                         shape.data(), layoutAttr, memSpaceAttr);
759            if (mlirTypeIsNull(t))
760              throw MLIRError("Invalid type", errors.take());
761            return PyMemRefType(elementType.getContext(), t);
762          },
763          nb::arg("shape"), nb::arg("element_type"),
764          nb::arg("layout").none() = nb::none(),
765          nb::arg("memory_space").none() = nb::none(),
766          nb::arg("loc").none() = nb::none(), "Create a memref type")
767         .def_prop_ro(
768             "layout",
769             [](PyMemRefType &self) -> MlirAttribute {
770               return mlirMemRefTypeGetLayout(self);
771             },
772             "The layout of the MemRef type.")
773         .def(
774             "get_strides_and_offset",
775             [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
776               std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
777               int64_t offset;
778               if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
779                       self, strides.data(), &offset)))
780                 throw std::runtime_error(
781                     "Failed to extract strides and offset from memref.");
782               return {strides, offset};
783             },
784             "The strides and offset of the MemRef type.")
785         .def_prop_ro(
786             "affine_map",
787             [](PyMemRefType &self) -> PyAffineMap {
788               MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
789               return PyAffineMap(self.getContext(), map);
790             },
791             "The layout of the MemRef type as an affine map.")
792         .def_prop_ro(
793             "memory_space",
794             [](PyMemRefType &self) -> std::optional<MlirAttribute> {
795               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
796               if (mlirAttributeIsNull(a))
797                 return std::nullopt;
798               return a;
799             },
800             "Returns the memory space of the given MemRef type.");
801   }
802 };
803 
804 /// Unranked MemRef Type subclass - UnrankedMemRefType.
805 class PyUnrankedMemRefType
806     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
807 public:
808   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
809   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
810       mlirUnrankedMemRefTypeGetTypeID;
811   static constexpr const char *pyClassName = "UnrankedMemRefType";
812   using PyConcreteType::PyConcreteType;
813 
814   static void bindDerived(ClassTy &c) {
815     c.def_static(
816          "get",
817          [](PyType &elementType, PyAttribute *memorySpace,
818             DefaultingPyLocation loc) {
819            PyMlirContext::ErrorCapture errors(loc->getContext());
820            MlirAttribute memSpaceAttr = {};
821            if (memorySpace)
822              memSpaceAttr = *memorySpace;
823 
824            MlirType t =
825                mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
826            if (mlirTypeIsNull(t))
827              throw MLIRError("Invalid type", errors.take());
828            return PyUnrankedMemRefType(elementType.getContext(), t);
829          },
830          nb::arg("element_type"), nb::arg("memory_space").none(),
831          nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
832         .def_prop_ro(
833             "memory_space",
834             [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
835               MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
836               if (mlirAttributeIsNull(a))
837                 return std::nullopt;
838               return a;
839             },
840             "Returns the memory space of the given Unranked MemRef type.");
841   }
842 };
843 
844 /// Tuple Type subclass - TupleType.
845 class PyTupleType : public PyConcreteType<PyTupleType> {
846 public:
847   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
848   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
849       mlirTupleTypeGetTypeID;
850   static constexpr const char *pyClassName = "TupleType";
851   using PyConcreteType::PyConcreteType;
852 
853   static void bindDerived(ClassTy &c) {
854     c.def_static(
855         "get_tuple",
856         [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
857           MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
858                                         elements.data());
859           return PyTupleType(context->getRef(), t);
860         },
861         nb::arg("elements"), nb::arg("context").none() = nb::none(),
862         "Create a tuple type");
863     c.def(
864         "get_type",
865         [](PyTupleType &self, intptr_t pos) {
866           return mlirTupleTypeGetType(self, pos);
867         },
868         nb::arg("pos"), "Returns the pos-th type in the tuple type.");
869     c.def_prop_ro(
870         "num_types",
871         [](PyTupleType &self) -> intptr_t {
872           return mlirTupleTypeGetNumTypes(self);
873         },
874         "Returns the number of types contained in a tuple.");
875   }
876 };
877 
878 /// Function type.
879 class PyFunctionType : public PyConcreteType<PyFunctionType> {
880 public:
881   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
882   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
883       mlirFunctionTypeGetTypeID;
884   static constexpr const char *pyClassName = "FunctionType";
885   using PyConcreteType::PyConcreteType;
886 
887   static void bindDerived(ClassTy &c) {
888     c.def_static(
889         "get",
890         [](std::vector<MlirType> inputs, std::vector<MlirType> results,
891            DefaultingPyMlirContext context) {
892           MlirType t =
893               mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
894                                   results.size(), results.data());
895           return PyFunctionType(context->getRef(), t);
896         },
897         nb::arg("inputs"), nb::arg("results"),
898         nb::arg("context").none() = nb::none(),
899         "Gets a FunctionType from a list of input and result types");
900     c.def_prop_ro(
901         "inputs",
902         [](PyFunctionType &self) {
903           MlirType t = self;
904           nb::list types;
905           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
906                ++i) {
907             types.append(mlirFunctionTypeGetInput(t, i));
908           }
909           return types;
910         },
911         "Returns the list of input types in the FunctionType.");
912     c.def_prop_ro(
913         "results",
914         [](PyFunctionType &self) {
915           nb::list types;
916           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
917                ++i) {
918             types.append(mlirFunctionTypeGetResult(self, i));
919           }
920           return types;
921         },
922         "Returns the list of result types in the FunctionType.");
923   }
924 };
925 
926 static MlirStringRef toMlirStringRef(const std::string &s) {
927   return mlirStringRefCreate(s.data(), s.size());
928 }
929 
930 /// Opaque Type subclass - OpaqueType.
931 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
932 public:
933   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
934   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
935       mlirOpaqueTypeGetTypeID;
936   static constexpr const char *pyClassName = "OpaqueType";
937   using PyConcreteType::PyConcreteType;
938 
939   static void bindDerived(ClassTy &c) {
940     c.def_static(
941         "get",
942         [](std::string dialectNamespace, std::string typeData,
943            DefaultingPyMlirContext context) {
944           MlirType type = mlirOpaqueTypeGet(context->get(),
945                                             toMlirStringRef(dialectNamespace),
946                                             toMlirStringRef(typeData));
947           return PyOpaqueType(context->getRef(), type);
948         },
949         nb::arg("dialect_namespace"), nb::arg("buffer"),
950         nb::arg("context").none() = nb::none(),
951         "Create an unregistered (opaque) dialect type.");
952     c.def_prop_ro(
953         "dialect_namespace",
954         [](PyOpaqueType &self) {
955           MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
956           return nb::str(stringRef.data, stringRef.length);
957         },
958         "Returns the dialect namespace for the Opaque type as a string.");
959     c.def_prop_ro(
960         "data",
961         [](PyOpaqueType &self) {
962           MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
963           return nb::str(stringRef.data, stringRef.length);
964         },
965         "Returns the data for the Opaque type as a string.");
966   }
967 };
968 
969 } // namespace
970 
971 void mlir::python::populateIRTypes(nb::module_ &m) {
972   PyIntegerType::bind(m);
973   PyFloatType::bind(m);
974   PyIndexType::bind(m);
975   PyFloat4E2M1FNType::bind(m);
976   PyFloat6E2M3FNType::bind(m);
977   PyFloat6E3M2FNType::bind(m);
978   PyFloat8E4M3FNType::bind(m);
979   PyFloat8E5M2Type::bind(m);
980   PyFloat8E4M3Type::bind(m);
981   PyFloat8E4M3FNUZType::bind(m);
982   PyFloat8E4M3B11FNUZType::bind(m);
983   PyFloat8E5M2FNUZType::bind(m);
984   PyFloat8E3M4Type::bind(m);
985   PyFloat8E8M0FNUType::bind(m);
986   PyBF16Type::bind(m);
987   PyF16Type::bind(m);
988   PyTF32Type::bind(m);
989   PyF32Type::bind(m);
990   PyF64Type::bind(m);
991   PyNoneType::bind(m);
992   PyComplexType::bind(m);
993   PyShapedType::bind(m);
994   PyVectorType::bind(m);
995   PyRankedTensorType::bind(m);
996   PyUnrankedTensorType::bind(m);
997   PyMemRefType::bind(m);
998   PyUnrankedMemRefType::bind(m);
999   PyTupleType::bind(m);
1000   PyFunctionType::bind(m);
1001   PyOpaqueType::bind(m);
1002 }
1003