1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // clang-format off 10 #include "IRModule.h" 11 #include "mlir/Bindings/Python/IRTypes.h" 12 // clang-format on 13 14 #include <optional> 15 16 #include "IRModule.h" 17 #include "NanobindUtils.h" 18 #include "mlir-c/BuiltinAttributes.h" 19 #include "mlir-c/BuiltinTypes.h" 20 #include "mlir-c/Support.h" 21 22 namespace nb = nanobind; 23 using namespace mlir; 24 using namespace mlir::python; 25 26 using llvm::SmallVector; 27 using llvm::Twine; 28 29 namespace { 30 31 /// Checks whether the given type is an integer or float type. 32 static int mlirTypeIsAIntegerOrFloat(MlirType type) { 33 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || 34 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); 35 } 36 37 class PyIntegerType : public PyConcreteType<PyIntegerType> { 38 public: 39 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; 40 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 41 mlirIntegerTypeGetTypeID; 42 static constexpr const char *pyClassName = "IntegerType"; 43 using PyConcreteType::PyConcreteType; 44 45 static void bindDerived(ClassTy &c) { 46 c.def_static( 47 "get_signless", 48 [](unsigned width, DefaultingPyMlirContext context) { 49 MlirType t = mlirIntegerTypeGet(context->get(), width); 50 return PyIntegerType(context->getRef(), t); 51 }, 52 nb::arg("width"), nb::arg("context").none() = nb::none(), 53 "Create a signless integer type"); 54 c.def_static( 55 "get_signed", 56 [](unsigned width, DefaultingPyMlirContext context) { 57 MlirType t = mlirIntegerTypeSignedGet(context->get(), width); 58 return PyIntegerType(context->getRef(), t); 59 }, 60 nb::arg("width"), nb::arg("context").none() = nb::none(), 61 "Create a signed integer type"); 62 c.def_static( 63 "get_unsigned", 64 [](unsigned width, DefaultingPyMlirContext context) { 65 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); 66 return PyIntegerType(context->getRef(), t); 67 }, 68 nb::arg("width"), nb::arg("context").none() = nb::none(), 69 "Create an unsigned integer type"); 70 c.def_prop_ro( 71 "width", 72 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, 73 "Returns the width of the integer type"); 74 c.def_prop_ro( 75 "is_signless", 76 [](PyIntegerType &self) -> bool { 77 return mlirIntegerTypeIsSignless(self); 78 }, 79 "Returns whether this is a signless integer"); 80 c.def_prop_ro( 81 "is_signed", 82 [](PyIntegerType &self) -> bool { 83 return mlirIntegerTypeIsSigned(self); 84 }, 85 "Returns whether this is a signed integer"); 86 c.def_prop_ro( 87 "is_unsigned", 88 [](PyIntegerType &self) -> bool { 89 return mlirIntegerTypeIsUnsigned(self); 90 }, 91 "Returns whether this is an unsigned integer"); 92 } 93 }; 94 95 /// Index Type subclass - IndexType. 96 class PyIndexType : public PyConcreteType<PyIndexType> { 97 public: 98 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; 99 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 100 mlirIndexTypeGetTypeID; 101 static constexpr const char *pyClassName = "IndexType"; 102 using PyConcreteType::PyConcreteType; 103 104 static void bindDerived(ClassTy &c) { 105 c.def_static( 106 "get", 107 [](DefaultingPyMlirContext context) { 108 MlirType t = mlirIndexTypeGet(context->get()); 109 return PyIndexType(context->getRef(), t); 110 }, 111 nb::arg("context").none() = nb::none(), "Create a index type."); 112 } 113 }; 114 115 class PyFloatType : public PyConcreteType<PyFloatType> { 116 public: 117 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; 118 static constexpr const char *pyClassName = "FloatType"; 119 using PyConcreteType::PyConcreteType; 120 121 static void bindDerived(ClassTy &c) { 122 c.def_prop_ro( 123 "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, 124 "Returns the width of the floating-point type"); 125 } 126 }; 127 128 /// Floating Point Type subclass - Float4E2M1FNType. 129 class PyFloat4E2M1FNType 130 : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { 131 public: 132 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; 133 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 134 mlirFloat4E2M1FNTypeGetTypeID; 135 static constexpr const char *pyClassName = "Float4E2M1FNType"; 136 using PyConcreteType::PyConcreteType; 137 138 static void bindDerived(ClassTy &c) { 139 c.def_static( 140 "get", 141 [](DefaultingPyMlirContext context) { 142 MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); 143 return PyFloat4E2M1FNType(context->getRef(), t); 144 }, 145 nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); 146 } 147 }; 148 149 /// Floating Point Type subclass - Float6E2M3FNType. 150 class PyFloat6E2M3FNType 151 : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { 152 public: 153 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; 154 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 155 mlirFloat6E2M3FNTypeGetTypeID; 156 static constexpr const char *pyClassName = "Float6E2M3FNType"; 157 using PyConcreteType::PyConcreteType; 158 159 static void bindDerived(ClassTy &c) { 160 c.def_static( 161 "get", 162 [](DefaultingPyMlirContext context) { 163 MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); 164 return PyFloat6E2M3FNType(context->getRef(), t); 165 }, 166 nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); 167 } 168 }; 169 170 /// Floating Point Type subclass - Float6E3M2FNType. 171 class PyFloat6E3M2FNType 172 : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { 173 public: 174 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; 175 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 176 mlirFloat6E3M2FNTypeGetTypeID; 177 static constexpr const char *pyClassName = "Float6E3M2FNType"; 178 using PyConcreteType::PyConcreteType; 179 180 static void bindDerived(ClassTy &c) { 181 c.def_static( 182 "get", 183 [](DefaultingPyMlirContext context) { 184 MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); 185 return PyFloat6E3M2FNType(context->getRef(), t); 186 }, 187 nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); 188 } 189 }; 190 191 /// Floating Point Type subclass - Float8E4M3FNType. 192 class PyFloat8E4M3FNType 193 : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { 194 public: 195 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; 196 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 197 mlirFloat8E4M3FNTypeGetTypeID; 198 static constexpr const char *pyClassName = "Float8E4M3FNType"; 199 using PyConcreteType::PyConcreteType; 200 201 static void bindDerived(ClassTy &c) { 202 c.def_static( 203 "get", 204 [](DefaultingPyMlirContext context) { 205 MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); 206 return PyFloat8E4M3FNType(context->getRef(), t); 207 }, 208 nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); 209 } 210 }; 211 212 /// Floating Point Type subclass - Float8E5M2Type. 213 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { 214 public: 215 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; 216 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 217 mlirFloat8E5M2TypeGetTypeID; 218 static constexpr const char *pyClassName = "Float8E5M2Type"; 219 using PyConcreteType::PyConcreteType; 220 221 static void bindDerived(ClassTy &c) { 222 c.def_static( 223 "get", 224 [](DefaultingPyMlirContext context) { 225 MlirType t = mlirFloat8E5M2TypeGet(context->get()); 226 return PyFloat8E5M2Type(context->getRef(), t); 227 }, 228 nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); 229 } 230 }; 231 232 /// Floating Point Type subclass - Float8E4M3Type. 233 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { 234 public: 235 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; 236 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 237 mlirFloat8E4M3TypeGetTypeID; 238 static constexpr const char *pyClassName = "Float8E4M3Type"; 239 using PyConcreteType::PyConcreteType; 240 241 static void bindDerived(ClassTy &c) { 242 c.def_static( 243 "get", 244 [](DefaultingPyMlirContext context) { 245 MlirType t = mlirFloat8E4M3TypeGet(context->get()); 246 return PyFloat8E4M3Type(context->getRef(), t); 247 }, 248 nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); 249 } 250 }; 251 252 /// Floating Point Type subclass - Float8E4M3FNUZ. 253 class PyFloat8E4M3FNUZType 254 : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { 255 public: 256 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; 257 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 258 mlirFloat8E4M3FNUZTypeGetTypeID; 259 static constexpr const char *pyClassName = "Float8E4M3FNUZType"; 260 using PyConcreteType::PyConcreteType; 261 262 static void bindDerived(ClassTy &c) { 263 c.def_static( 264 "get", 265 [](DefaultingPyMlirContext context) { 266 MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); 267 return PyFloat8E4M3FNUZType(context->getRef(), t); 268 }, 269 nb::arg("context").none() = nb::none(), 270 "Create a float8_e4m3fnuz type."); 271 } 272 }; 273 274 /// Floating Point Type subclass - Float8E4M3B11FNUZ. 275 class PyFloat8E4M3B11FNUZType 276 : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { 277 public: 278 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; 279 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 280 mlirFloat8E4M3B11FNUZTypeGetTypeID; 281 static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; 282 using PyConcreteType::PyConcreteType; 283 284 static void bindDerived(ClassTy &c) { 285 c.def_static( 286 "get", 287 [](DefaultingPyMlirContext context) { 288 MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); 289 return PyFloat8E4M3B11FNUZType(context->getRef(), t); 290 }, 291 nb::arg("context").none() = nb::none(), 292 "Create a float8_e4m3b11fnuz type."); 293 } 294 }; 295 296 /// Floating Point Type subclass - Float8E5M2FNUZ. 297 class PyFloat8E5M2FNUZType 298 : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { 299 public: 300 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; 301 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 302 mlirFloat8E5M2FNUZTypeGetTypeID; 303 static constexpr const char *pyClassName = "Float8E5M2FNUZType"; 304 using PyConcreteType::PyConcreteType; 305 306 static void bindDerived(ClassTy &c) { 307 c.def_static( 308 "get", 309 [](DefaultingPyMlirContext context) { 310 MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); 311 return PyFloat8E5M2FNUZType(context->getRef(), t); 312 }, 313 nb::arg("context").none() = nb::none(), 314 "Create a float8_e5m2fnuz type."); 315 } 316 }; 317 318 /// Floating Point Type subclass - Float8E3M4Type. 319 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { 320 public: 321 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; 322 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 323 mlirFloat8E3M4TypeGetTypeID; 324 static constexpr const char *pyClassName = "Float8E3M4Type"; 325 using PyConcreteType::PyConcreteType; 326 327 static void bindDerived(ClassTy &c) { 328 c.def_static( 329 "get", 330 [](DefaultingPyMlirContext context) { 331 MlirType t = mlirFloat8E3M4TypeGet(context->get()); 332 return PyFloat8E3M4Type(context->getRef(), t); 333 }, 334 nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); 335 } 336 }; 337 338 /// Floating Point Type subclass - Float8E8M0FNUType. 339 class PyFloat8E8M0FNUType 340 : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { 341 public: 342 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; 343 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 344 mlirFloat8E8M0FNUTypeGetTypeID; 345 static constexpr const char *pyClassName = "Float8E8M0FNUType"; 346 using PyConcreteType::PyConcreteType; 347 348 static void bindDerived(ClassTy &c) { 349 c.def_static( 350 "get", 351 [](DefaultingPyMlirContext context) { 352 MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); 353 return PyFloat8E8M0FNUType(context->getRef(), t); 354 }, 355 nb::arg("context").none() = nb::none(), 356 "Create a float8_e8m0fnu type."); 357 } 358 }; 359 360 /// Floating Point Type subclass - BF16Type. 361 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { 362 public: 363 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; 364 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 365 mlirBFloat16TypeGetTypeID; 366 static constexpr const char *pyClassName = "BF16Type"; 367 using PyConcreteType::PyConcreteType; 368 369 static void bindDerived(ClassTy &c) { 370 c.def_static( 371 "get", 372 [](DefaultingPyMlirContext context) { 373 MlirType t = mlirBF16TypeGet(context->get()); 374 return PyBF16Type(context->getRef(), t); 375 }, 376 nb::arg("context").none() = nb::none(), "Create a bf16 type."); 377 } 378 }; 379 380 /// Floating Point Type subclass - F16Type. 381 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { 382 public: 383 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; 384 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 385 mlirFloat16TypeGetTypeID; 386 static constexpr const char *pyClassName = "F16Type"; 387 using PyConcreteType::PyConcreteType; 388 389 static void bindDerived(ClassTy &c) { 390 c.def_static( 391 "get", 392 [](DefaultingPyMlirContext context) { 393 MlirType t = mlirF16TypeGet(context->get()); 394 return PyF16Type(context->getRef(), t); 395 }, 396 nb::arg("context").none() = nb::none(), "Create a f16 type."); 397 } 398 }; 399 400 /// Floating Point Type subclass - TF32Type. 401 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { 402 public: 403 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; 404 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 405 mlirFloatTF32TypeGetTypeID; 406 static constexpr const char *pyClassName = "FloatTF32Type"; 407 using PyConcreteType::PyConcreteType; 408 409 static void bindDerived(ClassTy &c) { 410 c.def_static( 411 "get", 412 [](DefaultingPyMlirContext context) { 413 MlirType t = mlirTF32TypeGet(context->get()); 414 return PyTF32Type(context->getRef(), t); 415 }, 416 nb::arg("context").none() = nb::none(), "Create a tf32 type."); 417 } 418 }; 419 420 /// Floating Point Type subclass - F32Type. 421 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { 422 public: 423 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; 424 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 425 mlirFloat32TypeGetTypeID; 426 static constexpr const char *pyClassName = "F32Type"; 427 using PyConcreteType::PyConcreteType; 428 429 static void bindDerived(ClassTy &c) { 430 c.def_static( 431 "get", 432 [](DefaultingPyMlirContext context) { 433 MlirType t = mlirF32TypeGet(context->get()); 434 return PyF32Type(context->getRef(), t); 435 }, 436 nb::arg("context").none() = nb::none(), "Create a f32 type."); 437 } 438 }; 439 440 /// Floating Point Type subclass - F64Type. 441 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { 442 public: 443 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; 444 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 445 mlirFloat64TypeGetTypeID; 446 static constexpr const char *pyClassName = "F64Type"; 447 using PyConcreteType::PyConcreteType; 448 449 static void bindDerived(ClassTy &c) { 450 c.def_static( 451 "get", 452 [](DefaultingPyMlirContext context) { 453 MlirType t = mlirF64TypeGet(context->get()); 454 return PyF64Type(context->getRef(), t); 455 }, 456 nb::arg("context").none() = nb::none(), "Create a f64 type."); 457 } 458 }; 459 460 /// None Type subclass - NoneType. 461 class PyNoneType : public PyConcreteType<PyNoneType> { 462 public: 463 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; 464 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 465 mlirNoneTypeGetTypeID; 466 static constexpr const char *pyClassName = "NoneType"; 467 using PyConcreteType::PyConcreteType; 468 469 static void bindDerived(ClassTy &c) { 470 c.def_static( 471 "get", 472 [](DefaultingPyMlirContext context) { 473 MlirType t = mlirNoneTypeGet(context->get()); 474 return PyNoneType(context->getRef(), t); 475 }, 476 nb::arg("context").none() = nb::none(), "Create a none type."); 477 } 478 }; 479 480 /// Complex Type subclass - ComplexType. 481 class PyComplexType : public PyConcreteType<PyComplexType> { 482 public: 483 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; 484 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 485 mlirComplexTypeGetTypeID; 486 static constexpr const char *pyClassName = "ComplexType"; 487 using PyConcreteType::PyConcreteType; 488 489 static void bindDerived(ClassTy &c) { 490 c.def_static( 491 "get", 492 [](PyType &elementType) { 493 // The element must be a floating point or integer scalar type. 494 if (mlirTypeIsAIntegerOrFloat(elementType)) { 495 MlirType t = mlirComplexTypeGet(elementType); 496 return PyComplexType(elementType.getContext(), t); 497 } 498 throw nb::value_error( 499 (Twine("invalid '") + 500 nb::cast<std::string>(nb::repr(nb::cast(elementType))) + 501 "' and expected floating point or integer type.") 502 .str() 503 .c_str()); 504 }, 505 "Create a complex type"); 506 c.def_prop_ro( 507 "element_type", 508 [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, 509 "Returns element type."); 510 } 511 }; 512 513 } // namespace 514 515 // Shaped Type Interface - ShapedType 516 void mlir::PyShapedType::bindDerived(ClassTy &c) { 517 c.def_prop_ro( 518 "element_type", 519 [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, 520 "Returns the element type of the shaped type."); 521 c.def_prop_ro( 522 "has_rank", 523 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, 524 "Returns whether the given shaped type is ranked."); 525 c.def_prop_ro( 526 "rank", 527 [](PyShapedType &self) { 528 self.requireHasRank(); 529 return mlirShapedTypeGetRank(self); 530 }, 531 "Returns the rank of the given ranked shaped type."); 532 c.def_prop_ro( 533 "has_static_shape", 534 [](PyShapedType &self) -> bool { 535 return mlirShapedTypeHasStaticShape(self); 536 }, 537 "Returns whether the given shaped type has a static shape."); 538 c.def( 539 "is_dynamic_dim", 540 [](PyShapedType &self, intptr_t dim) -> bool { 541 self.requireHasRank(); 542 return mlirShapedTypeIsDynamicDim(self, dim); 543 }, 544 nb::arg("dim"), 545 "Returns whether the dim-th dimension of the given shaped type is " 546 "dynamic."); 547 c.def( 548 "get_dim_size", 549 [](PyShapedType &self, intptr_t dim) { 550 self.requireHasRank(); 551 return mlirShapedTypeGetDimSize(self, dim); 552 }, 553 nb::arg("dim"), 554 "Returns the dim-th dimension of the given ranked shaped type."); 555 c.def_static( 556 "is_dynamic_size", 557 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, 558 nb::arg("dim_size"), 559 "Returns whether the given dimension size indicates a dynamic " 560 "dimension."); 561 c.def( 562 "is_dynamic_stride_or_offset", 563 [](PyShapedType &self, int64_t val) -> bool { 564 self.requireHasRank(); 565 return mlirShapedTypeIsDynamicStrideOrOffset(val); 566 }, 567 nb::arg("dim_size"), 568 "Returns whether the given value is used as a placeholder for dynamic " 569 "strides and offsets in shaped types."); 570 c.def_prop_ro( 571 "shape", 572 [](PyShapedType &self) { 573 self.requireHasRank(); 574 575 std::vector<int64_t> shape; 576 int64_t rank = mlirShapedTypeGetRank(self); 577 shape.reserve(rank); 578 for (int64_t i = 0; i < rank; ++i) 579 shape.push_back(mlirShapedTypeGetDimSize(self, i)); 580 return shape; 581 }, 582 "Returns the shape of the ranked shaped type as a list of integers."); 583 c.def_static( 584 "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, 585 "Returns the value used to indicate dynamic dimensions in shaped " 586 "types."); 587 c.def_static( 588 "get_dynamic_stride_or_offset", 589 []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, 590 "Returns the value used to indicate dynamic strides or offsets in " 591 "shaped types."); 592 } 593 594 void mlir::PyShapedType::requireHasRank() { 595 if (!mlirShapedTypeHasRank(*this)) { 596 throw nb::value_error( 597 "calling this method requires that the type has a rank."); 598 } 599 } 600 601 const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = 602 mlirTypeIsAShaped; 603 604 namespace { 605 606 /// Vector Type subclass - VectorType. 607 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { 608 public: 609 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; 610 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 611 mlirVectorTypeGetTypeID; 612 static constexpr const char *pyClassName = "VectorType"; 613 using PyConcreteType::PyConcreteType; 614 615 static void bindDerived(ClassTy &c) { 616 c.def_static("get", &PyVectorType::get, nb::arg("shape"), 617 nb::arg("element_type"), nb::kw_only(), 618 nb::arg("scalable").none() = nb::none(), 619 nb::arg("scalable_dims").none() = nb::none(), 620 nb::arg("loc").none() = nb::none(), "Create a vector type") 621 .def_prop_ro( 622 "scalable", 623 [](MlirType self) { return mlirVectorTypeIsScalable(self); }) 624 .def_prop_ro("scalable_dims", [](MlirType self) { 625 std::vector<bool> scalableDims; 626 size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); 627 scalableDims.reserve(rank); 628 for (size_t i = 0; i < rank; ++i) 629 scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); 630 return scalableDims; 631 }); 632 } 633 634 private: 635 static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, 636 std::optional<nb::list> scalable, 637 std::optional<std::vector<int64_t>> scalableDims, 638 DefaultingPyLocation loc) { 639 if (scalable && scalableDims) { 640 throw nb::value_error("'scalable' and 'scalable_dims' kwargs " 641 "are mutually exclusive."); 642 } 643 644 PyMlirContext::ErrorCapture errors(loc->getContext()); 645 MlirType type; 646 if (scalable) { 647 if (scalable->size() != shape.size()) 648 throw nb::value_error("Expected len(scalable) == len(shape)."); 649 650 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( 651 *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); })); 652 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), 653 scalableDimFlags.data(), 654 elementType); 655 } else if (scalableDims) { 656 SmallVector<bool> scalableDimFlags(shape.size(), false); 657 for (int64_t dim : *scalableDims) { 658 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) 659 throw nb::value_error("Scalable dimension index out of bounds."); 660 scalableDimFlags[dim] = true; 661 } 662 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), 663 scalableDimFlags.data(), 664 elementType); 665 } else { 666 type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), 667 elementType); 668 } 669 if (mlirTypeIsNull(type)) 670 throw MLIRError("Invalid type", errors.take()); 671 return PyVectorType(elementType.getContext(), type); 672 } 673 }; 674 675 /// Ranked Tensor Type subclass - RankedTensorType. 676 class PyRankedTensorType 677 : public PyConcreteType<PyRankedTensorType, PyShapedType> { 678 public: 679 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; 680 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 681 mlirRankedTensorTypeGetTypeID; 682 static constexpr const char *pyClassName = "RankedTensorType"; 683 using PyConcreteType::PyConcreteType; 684 685 static void bindDerived(ClassTy &c) { 686 c.def_static( 687 "get", 688 [](std::vector<int64_t> shape, PyType &elementType, 689 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { 690 PyMlirContext::ErrorCapture errors(loc->getContext()); 691 MlirType t = mlirRankedTensorTypeGetChecked( 692 loc, shape.size(), shape.data(), elementType, 693 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); 694 if (mlirTypeIsNull(t)) 695 throw MLIRError("Invalid type", errors.take()); 696 return PyRankedTensorType(elementType.getContext(), t); 697 }, 698 nb::arg("shape"), nb::arg("element_type"), 699 nb::arg("encoding").none() = nb::none(), 700 nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); 701 c.def_prop_ro("encoding", 702 [](PyRankedTensorType &self) -> std::optional<MlirAttribute> { 703 MlirAttribute encoding = 704 mlirRankedTensorTypeGetEncoding(self.get()); 705 if (mlirAttributeIsNull(encoding)) 706 return std::nullopt; 707 return encoding; 708 }); 709 } 710 }; 711 712 /// Unranked Tensor Type subclass - UnrankedTensorType. 713 class PyUnrankedTensorType 714 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { 715 public: 716 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; 717 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 718 mlirUnrankedTensorTypeGetTypeID; 719 static constexpr const char *pyClassName = "UnrankedTensorType"; 720 using PyConcreteType::PyConcreteType; 721 722 static void bindDerived(ClassTy &c) { 723 c.def_static( 724 "get", 725 [](PyType &elementType, DefaultingPyLocation loc) { 726 PyMlirContext::ErrorCapture errors(loc->getContext()); 727 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); 728 if (mlirTypeIsNull(t)) 729 throw MLIRError("Invalid type", errors.take()); 730 return PyUnrankedTensorType(elementType.getContext(), t); 731 }, 732 nb::arg("element_type"), nb::arg("loc").none() = nb::none(), 733 "Create a unranked tensor type"); 734 } 735 }; 736 737 /// Ranked MemRef Type subclass - MemRefType. 738 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { 739 public: 740 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; 741 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 742 mlirMemRefTypeGetTypeID; 743 static constexpr const char *pyClassName = "MemRefType"; 744 using PyConcreteType::PyConcreteType; 745 746 static void bindDerived(ClassTy &c) { 747 c.def_static( 748 "get", 749 [](std::vector<int64_t> shape, PyType &elementType, 750 PyAttribute *layout, PyAttribute *memorySpace, 751 DefaultingPyLocation loc) { 752 PyMlirContext::ErrorCapture errors(loc->getContext()); 753 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); 754 MlirAttribute memSpaceAttr = 755 memorySpace ? *memorySpace : mlirAttributeGetNull(); 756 MlirType t = 757 mlirMemRefTypeGetChecked(loc, elementType, shape.size(), 758 shape.data(), layoutAttr, memSpaceAttr); 759 if (mlirTypeIsNull(t)) 760 throw MLIRError("Invalid type", errors.take()); 761 return PyMemRefType(elementType.getContext(), t); 762 }, 763 nb::arg("shape"), nb::arg("element_type"), 764 nb::arg("layout").none() = nb::none(), 765 nb::arg("memory_space").none() = nb::none(), 766 nb::arg("loc").none() = nb::none(), "Create a memref type") 767 .def_prop_ro( 768 "layout", 769 [](PyMemRefType &self) -> MlirAttribute { 770 return mlirMemRefTypeGetLayout(self); 771 }, 772 "The layout of the MemRef type.") 773 .def( 774 "get_strides_and_offset", 775 [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { 776 std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); 777 int64_t offset; 778 if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( 779 self, strides.data(), &offset))) 780 throw std::runtime_error( 781 "Failed to extract strides and offset from memref."); 782 return {strides, offset}; 783 }, 784 "The strides and offset of the MemRef type.") 785 .def_prop_ro( 786 "affine_map", 787 [](PyMemRefType &self) -> PyAffineMap { 788 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); 789 return PyAffineMap(self.getContext(), map); 790 }, 791 "The layout of the MemRef type as an affine map.") 792 .def_prop_ro( 793 "memory_space", 794 [](PyMemRefType &self) -> std::optional<MlirAttribute> { 795 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); 796 if (mlirAttributeIsNull(a)) 797 return std::nullopt; 798 return a; 799 }, 800 "Returns the memory space of the given MemRef type."); 801 } 802 }; 803 804 /// Unranked MemRef Type subclass - UnrankedMemRefType. 805 class PyUnrankedMemRefType 806 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { 807 public: 808 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; 809 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 810 mlirUnrankedMemRefTypeGetTypeID; 811 static constexpr const char *pyClassName = "UnrankedMemRefType"; 812 using PyConcreteType::PyConcreteType; 813 814 static void bindDerived(ClassTy &c) { 815 c.def_static( 816 "get", 817 [](PyType &elementType, PyAttribute *memorySpace, 818 DefaultingPyLocation loc) { 819 PyMlirContext::ErrorCapture errors(loc->getContext()); 820 MlirAttribute memSpaceAttr = {}; 821 if (memorySpace) 822 memSpaceAttr = *memorySpace; 823 824 MlirType t = 825 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); 826 if (mlirTypeIsNull(t)) 827 throw MLIRError("Invalid type", errors.take()); 828 return PyUnrankedMemRefType(elementType.getContext(), t); 829 }, 830 nb::arg("element_type"), nb::arg("memory_space").none(), 831 nb::arg("loc").none() = nb::none(), "Create a unranked memref type") 832 .def_prop_ro( 833 "memory_space", 834 [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> { 835 MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); 836 if (mlirAttributeIsNull(a)) 837 return std::nullopt; 838 return a; 839 }, 840 "Returns the memory space of the given Unranked MemRef type."); 841 } 842 }; 843 844 /// Tuple Type subclass - TupleType. 845 class PyTupleType : public PyConcreteType<PyTupleType> { 846 public: 847 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; 848 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 849 mlirTupleTypeGetTypeID; 850 static constexpr const char *pyClassName = "TupleType"; 851 using PyConcreteType::PyConcreteType; 852 853 static void bindDerived(ClassTy &c) { 854 c.def_static( 855 "get_tuple", 856 [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { 857 MlirType t = mlirTupleTypeGet(context->get(), elements.size(), 858 elements.data()); 859 return PyTupleType(context->getRef(), t); 860 }, 861 nb::arg("elements"), nb::arg("context").none() = nb::none(), 862 "Create a tuple type"); 863 c.def( 864 "get_type", 865 [](PyTupleType &self, intptr_t pos) { 866 return mlirTupleTypeGetType(self, pos); 867 }, 868 nb::arg("pos"), "Returns the pos-th type in the tuple type."); 869 c.def_prop_ro( 870 "num_types", 871 [](PyTupleType &self) -> intptr_t { 872 return mlirTupleTypeGetNumTypes(self); 873 }, 874 "Returns the number of types contained in a tuple."); 875 } 876 }; 877 878 /// Function type. 879 class PyFunctionType : public PyConcreteType<PyFunctionType> { 880 public: 881 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; 882 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 883 mlirFunctionTypeGetTypeID; 884 static constexpr const char *pyClassName = "FunctionType"; 885 using PyConcreteType::PyConcreteType; 886 887 static void bindDerived(ClassTy &c) { 888 c.def_static( 889 "get", 890 [](std::vector<MlirType> inputs, std::vector<MlirType> results, 891 DefaultingPyMlirContext context) { 892 MlirType t = 893 mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), 894 results.size(), results.data()); 895 return PyFunctionType(context->getRef(), t); 896 }, 897 nb::arg("inputs"), nb::arg("results"), 898 nb::arg("context").none() = nb::none(), 899 "Gets a FunctionType from a list of input and result types"); 900 c.def_prop_ro( 901 "inputs", 902 [](PyFunctionType &self) { 903 MlirType t = self; 904 nb::list types; 905 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; 906 ++i) { 907 types.append(mlirFunctionTypeGetInput(t, i)); 908 } 909 return types; 910 }, 911 "Returns the list of input types in the FunctionType."); 912 c.def_prop_ro( 913 "results", 914 [](PyFunctionType &self) { 915 nb::list types; 916 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; 917 ++i) { 918 types.append(mlirFunctionTypeGetResult(self, i)); 919 } 920 return types; 921 }, 922 "Returns the list of result types in the FunctionType."); 923 } 924 }; 925 926 static MlirStringRef toMlirStringRef(const std::string &s) { 927 return mlirStringRefCreate(s.data(), s.size()); 928 } 929 930 /// Opaque Type subclass - OpaqueType. 931 class PyOpaqueType : public PyConcreteType<PyOpaqueType> { 932 public: 933 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; 934 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 935 mlirOpaqueTypeGetTypeID; 936 static constexpr const char *pyClassName = "OpaqueType"; 937 using PyConcreteType::PyConcreteType; 938 939 static void bindDerived(ClassTy &c) { 940 c.def_static( 941 "get", 942 [](std::string dialectNamespace, std::string typeData, 943 DefaultingPyMlirContext context) { 944 MlirType type = mlirOpaqueTypeGet(context->get(), 945 toMlirStringRef(dialectNamespace), 946 toMlirStringRef(typeData)); 947 return PyOpaqueType(context->getRef(), type); 948 }, 949 nb::arg("dialect_namespace"), nb::arg("buffer"), 950 nb::arg("context").none() = nb::none(), 951 "Create an unregistered (opaque) dialect type."); 952 c.def_prop_ro( 953 "dialect_namespace", 954 [](PyOpaqueType &self) { 955 MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); 956 return nb::str(stringRef.data, stringRef.length); 957 }, 958 "Returns the dialect namespace for the Opaque type as a string."); 959 c.def_prop_ro( 960 "data", 961 [](PyOpaqueType &self) { 962 MlirStringRef stringRef = mlirOpaqueTypeGetData(self); 963 return nb::str(stringRef.data, stringRef.length); 964 }, 965 "Returns the data for the Opaque type as a string."); 966 } 967 }; 968 969 } // namespace 970 971 void mlir::python::populateIRTypes(nb::module_ &m) { 972 PyIntegerType::bind(m); 973 PyFloatType::bind(m); 974 PyIndexType::bind(m); 975 PyFloat4E2M1FNType::bind(m); 976 PyFloat6E2M3FNType::bind(m); 977 PyFloat6E3M2FNType::bind(m); 978 PyFloat8E4M3FNType::bind(m); 979 PyFloat8E5M2Type::bind(m); 980 PyFloat8E4M3Type::bind(m); 981 PyFloat8E4M3FNUZType::bind(m); 982 PyFloat8E4M3B11FNUZType::bind(m); 983 PyFloat8E5M2FNUZType::bind(m); 984 PyFloat8E3M4Type::bind(m); 985 PyFloat8E8M0FNUType::bind(m); 986 PyBF16Type::bind(m); 987 PyF16Type::bind(m); 988 PyTF32Type::bind(m); 989 PyF32Type::bind(m); 990 PyF64Type::bind(m); 991 PyNoneType::bind(m); 992 PyComplexType::bind(m); 993 PyShapedType::bind(m); 994 PyVectorType::bind(m); 995 PyRankedTensorType::bind(m); 996 PyUnrankedTensorType::bind(m); 997 PyMemRefType::bind(m); 998 PyUnrankedMemRefType::bind(m); 999 PyTupleType::bind(m); 1000 PyFunctionType::bind(m); 1001 PyOpaqueType::bind(m); 1002 } 1003