# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
import mlir.extras.types as T


def run(f):
    """Run the test function `f` immediately, at decoration time.

    Prints a "TEST:" header naming the test, calls it, then forces a garbage
    collection and asserts that no MLIR Context objects remain alive, so every
    test must release each context it created. Returns `f` unchanged so the
    decorated module-level name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
    ctx = Context()
    t = Type.parse("i32", ctx)
    assert t.context is ctx
    # Drop the local context reference; the parsed type must stay printable.
    ctx = None
    gc.collect()
    # CHECK: i32
    print(str(t))
    # CHECK: Type(i32)
    print(repr(t))


# CHECK-LABEL: TEST: testParseError
@run
def testParseError():
    ctx = Context()
    try:
        Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
    except MLIRError as e:
        # CHECK: testParseError: <
        # CHECK: Unable to parse type:
        # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
        # CHECK: >
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")


# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    t3 = Type.parse("i32", ctx)
    # CHECK: t1 == t1: True
    print("t1 == t1:", t1 == t1)
    # CHECK: t1 == t2: False
    print("t1 == t2:", t1 == t2)
    # CHECK: t1 == t3: True
    print("t1 == t3:", t1 == t3)
    # CHECK: t1 is None: False
    print("t1 is None:", t1 is None)


# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    t3 = Type.parse("i32", ctx)

    # CHECK: hash(t1) == hash(t3): True
    print("hash(t1) == hash(t3):", hash(t1) == hash(t3))

    # Equal types (t1 and t3 are both i32) must collapse into one set entry.
    s = set()
    s.add(t1)
    s.add(t2)
    s.add(t3)
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type(t1)
    # CHECK: t1 == t2: True
    print("t1 == t2:", t1 == t2)


# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    # CHECK: True
    print(IntegerType.isinstance(t1))
    # CHECK: False
    print(F32Type.isinstance(t1))
    # CHECK: False
    print(FloatType.isinstance(t1))
    # CHECK: True
    print(F32Type.isinstance(t2))
    # CHECK: True
    print(FloatType.isinstance(t2))


# CHECK-LABEL: TEST: testFloatTypeSubclasses
@run
def testFloatTypeSubclasses():
    ctx = Context()
    # CHECK: True
    print(isinstance(Type.parse("f4E2M1FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E8M0FNU", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f16", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("bf16", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f32", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("tf32", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f64", ctx), FloatType))


# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    not_a_type = "foo"
    # CHECK: False
    print(t1 == not_a_type)
    # CHECK: False
    print(t1 is None)
    # CHECK: True
    print(t1 is not None)


# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
    with Context() as ctx:
        t1 = Type.parse("i32", ctx)
    # CHECK: mlir.ir.Type._CAPIPtr
    type_capsule = t1._CAPIPtr
    print(type_capsule)
    t2 = Type._CAPICreate(type_capsule)
    assert t2 == t1
    assert t2.context is ctx


# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    tint = IntegerType(t1)
    # Casting an already-cast IntegerType to IntegerType must also succeed.
    IntegerType(tint)
    # CHECK: Type(i32)
    print(repr(tint))
    try:
        IntegerType(Type.parse("f32", ctx))
    except ValueError as e:
        # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
        print("ValueError:", e)
    else:
        print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
    with Context():
        i32 = IntegerType(Type.parse("i32"))
        # CHECK: i32 width: 32
        print("i32 width:", i32.width)
        # CHECK: i32 signless: True
        print("i32 signless:", i32.is_signless)
        # CHECK: i32 signed: False
        print("i32 signed:", i32.is_signed)
        # CHECK: i32 unsigned: False
        print("i32 unsigned:", i32.is_unsigned)

        s32 = IntegerType(Type.parse("si32"))
        # CHECK: s32 signless: False
        print("s32 signless:", s32.is_signless)
        # CHECK: s32 signed: True
        print("s32 signed:", s32.is_signed)
        # CHECK: s32 unsigned: False
        print("s32 unsigned:", s32.is_unsigned)

        u32 = IntegerType(Type.parse("ui32"))
        # CHECK: u32 signless: False
        print("u32 signless:", u32.is_signless)
        # CHECK: u32 signed: False
        print("u32 signed:", u32.is_signed)
        # CHECK: u32 unsigned: True
        print("u32 unsigned:", u32.is_unsigned)

        # CHECK: signless: i16
        print("signless:", IntegerType.get_signless(16))
        # CHECK: signed: si8
        print("signed:", IntegerType.get_signed(8))
        # CHECK: unsigned: ui64
        print("unsigned:", IntegerType.get_unsigned(64))


# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
    with Context():
        # CHECK: index type: index
        print("index type:", IndexType.get())


# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
    with Context():
        # CHECK: float: f4E2M1FN
        print("float:", Float4E2M1FNType.get())
        # CHECK: float: f6E2M3FN
        print("float:", Float6E2M3FNType.get())
        # CHECK: float: f6E3M2FN
        print("float:", Float6E3M2FNType.get())
        # CHECK: float: f8E3M4
        print("float:", Float8E3M4Type.get())
        # CHECK: float: f8E4M3
        print("float:", Float8E4M3Type.get())
        # CHECK: float: f8E4M3FN
        print("float:", Float8E4M3FNType.get())
        # CHECK: float: f8E5M2
        print("float:", Float8E5M2Type.get())
        # CHECK: float: f8E5M2FNUZ
        print("float:", Float8E5M2FNUZType.get())
        # CHECK: float: f8E4M3FNUZ
        print("float:", Float8E4M3FNUZType.get())
        # CHECK: float: f8E4M3B11FNUZ
        print("float:", Float8E4M3B11FNUZType.get())
        # CHECK: float: f8E8M0FNU
        print("float:", Float8E8M0FNUType.get())
        # CHECK: float: bf16
        print("float:", BF16Type.get())
        # CHECK: float: f16
        print("float:", F16Type.get())
        # CHECK: float: tf32
        print("float:", FloatTF32Type.get())
        # CHECK: float: f32
        print("float:", F32Type.get())
        # CHECK: float: f64
        f64 = F64Type.get()
        print("float:", f64)
        # CHECK: f64 width: 64
        print("f64 width:", f64.width)


# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
    with Context():
        # CHECK: none type: none
        print("none type:", NoneType.get())


# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
    with Context():
        complex_i32 = ComplexType(Type.parse("complex<i32>"))
        # CHECK: complex type element: i32
        print("complex type element:", complex_i32.element_type)

        f32 = F32Type.get()
        # CHECK: complex type: complex<f32>
        print("complex type:", ComplexType.get(f32))

        index = IndexType.get()
        try:
            ComplexType.get(index)
        except ValueError as e:
            # CHECK: invalid 'Type(index)' and expected floating point or integer type.
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not itself a concrete builtin type; it is the base class for
# vectors, memrefs and tensors, so this test uses a vector instance to exercise
# the shaped-type API. The class hierarchy is preserved on the Python side.
@run
def testConcreteShapedType():
    with Context():
        vector = VectorType(Type.parse("vector<2x3xf32>"))
        # CHECK: element type: f32
        print("element type:", vector.element_type)
        # CHECK: whether the given shaped type is ranked: True
        print("whether the given shaped type is ranked:", vector.has_rank)
        # CHECK: rank: 2
        print("rank:", vector.rank)
        # CHECK: whether the shaped type has a static shape: True
        print("whether the shaped type has a static shape:", vector.has_static_shape)
        # CHECK: whether the dim-th dimension is dynamic: False
        print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
        # CHECK: dim size: 3
        print("dim size:", vector.get_dim_size(1))
        # CHECK: is_dynamic_size: False
        print("is_dynamic_size:", vector.is_dynamic_size(3))
        # CHECK: is_dynamic_stride_or_offset: False
        print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
        # CHECK: isinstance(ShapedType): True
        print("isinstance(ShapedType):", isinstance(vector, ShapedType))


# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
@run
def testAbstractShapedType():
    ctx = Context()
    vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
    # CHECK: element type: f32
    print("element type:", vector.element_type)


# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        # CHECK: vector type: vector<2x3xf32>
        print("vector type:", VectorType.get(shape, f32))

        none = NoneType.get()
        try:
            VectorType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point
            print(e)
        else:
            print("Exception not produced")

        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
        scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
        assert scalable_1.scalable
        assert scalable_2.scalable
        assert scalable_1.scalable_dims == [False, True]
        assert scalable_2.scalable_dims == [True, False, True]
        # CHECK: scalable 1: vector<2x[3]xf32>
        print("scalable 1: ", scalable_1)
        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
        print("scalable 2: ", scalable_2)

        # The index-list form (scalable_dims) must build the same types as the
        # boolean-mask form (scalable).
        scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
        scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        try:
            VectorType.get(shape, f32, scalable=[False, True, True])
        except ValueError as e:
            # CHECK: Expected len(scalable) == len(shape).
            print(e)
        else:
            print("Exception not produced")

        try:
            VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
        except ValueError as e:
            # CHECK: kwargs are mutually exclusive.
            print(e)
        else:
            print("Exception not produced")

        try:
            VectorType.get(shape, f32, scalable_dims=[42])
        except ValueError as e:
            # CHECK: Scalable dimension index out of bounds.
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        # CHECK: ranked tensor type: tensor<2x3xf32>
        print("ranked tensor type:", RankedTensorType.get(shape, f32))

        none = NoneType.get()
        try:
            RankedTensorType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(e)
        else:
            print("Exception not produced")

        # Named tensor_type rather than `tensor` so the imported `tensor`
        # dialect module is not shadowed inside this function.
        tensor_type = RankedTensorType.get(shape, f32, StringAttr.get("encoding"))
        assert tensor_type.shape == shape
        assert tensor_type.encoding.value == "encoding"

        # Encoding should be None.
        assert RankedTensorType.get(shape, f32).encoding is None


# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
    with Context(), Location.unknown():
        f32 = F32Type.get()
        unranked_tensor = UnrankedTensorType.get(f32)
        # CHECK: unranked tensor type: tensor<*xf32>
        print("unranked tensor type:", unranked_tensor)
        try:
            invalid_rank = unranked_tensor.rank
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_get_dim_size = unranked_tensor.get_dim_size(1)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedTensorType.get(none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
        # CHECK: memref type: memref<2x3xf32, 2>
        print("memref type:", memref_f32)
        # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>)
        print("memref layout:", repr(memref_f32.layout))
        # CHECK: memref affine map: (d0, d1) -> (d0, d1)
        print("memref affine map:", memref_f32.affine_map)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(memref_f32.memory_space))

        layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
        memref_layout = MemRefType.get(shape, f32, layout=layout)
        # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
        print("memref type:", memref_layout)
        # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
        print("memref layout:", memref_layout.layout)
        # CHECK: memref affine map: (d0, d1) -> (d1, d0)
        print("memref affine map:", memref_layout.affine_map)
        # CHECK: memory space: None
        print("memory space:", memref_layout.memory_space)

        none = NoneType.get()
        try:
            MemRefType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(e)
        else:
            print("Exception not produced")

        assert memref_f32.shape == shape


# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
    with Context(), Location.unknown():
        f32 = F32Type.get()
        unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
        # CHECK: unranked memref type: memref<*xf32, 2>
        print("unranked memref type:", unranked_memref)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(unranked_memref.memory_space))
        try:
            invalid_rank = unranked_memref.rank
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_get_dim_size = unranked_memref.get_dim_size(1)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedMemRefType.get(none, Attribute.parse("2"))
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testTupleType
@run
def testTupleType():
    with Context():
        i32 = IntegerType(Type.parse("i32"))
        f32 = F32Type.get()
        vector = VectorType(Type.parse("vector<2x3xf32>"))
        element_types = [i32, f32, vector]
        tuple_type = TupleType.get_tuple(element_types)
        # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
        print("tuple type:", tuple_type)
        # CHECK: number of types: 3
        print("number of types:", tuple_type.num_types)
        # CHECK: pos-th type in the tuple type: f32
        print("pos-th type in the tuple type:", tuple_type.get_type(1))


# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
    with Context():
        input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
        result_types = [IndexType.get()]
        # Named func_type rather than `func` so the imported `func` dialect
        # module is not shadowed inside this function.
        func_type = FunctionType.get(input_types, result_types)
        # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
        print("INPUTS:", func_type.inputs)
        # CHECK: RESULTS: [IndexType(index)]
        print("RESULTS:", func_type.results)


# CHECK-LABEL: TEST: testOpaqueType
@run
def testOpaqueType():
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        opaque = OpaqueType.get("dialect", "type")
        # CHECK: opaque type: !dialect.type
        print("opaque type:", opaque)
        # CHECK: dialect namespace: dialect
        print("dialect namespace:", opaque.dialect_namespace)
        # CHECK: data: type
        print("data:", opaque.data)


# CHECK-LABEL: TEST: testShapedTypeConstants
# Tests that ShapedType exposes magic value constants.
@run
def testShapedTypeConstants():
    # CHECK: <class 'int'>
    print(type(ShapedType.get_dynamic_size()))
    # CHECK: <class 'int'>
    print(type(ShapedType.get_dynamic_stride_or_offset()))


# CHECK-LABEL: TEST: testTypeIDs
@run
def testTypeIDs():
    with Context(), Location.unknown():
        f32 = F32Type.get()

        # (concrete type class, an instance of that class) pairs.
        types = [
            (IntegerType, IntegerType.get_signless(16)),
            (IndexType, IndexType.get()),
            (Float4E2M1FNType, Float4E2M1FNType.get()),
            (Float6E2M3FNType, Float6E2M3FNType.get()),
            (Float6E3M2FNType, Float6E3M2FNType.get()),
            (Float8E3M4Type, Float8E3M4Type.get()),
            (Float8E4M3Type, Float8E4M3Type.get()),
            (Float8E4M3FNType, Float8E4M3FNType.get()),
            (Float8E5M2Type, Float8E5M2Type.get()),
            (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
            (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
            (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
            (Float8E8M0FNUType, Float8E8M0FNUType.get()),
            (BF16Type, BF16Type.get()),
            (F16Type, F16Type.get()),
            (F32Type, F32Type.get()),
            (FloatTF32Type, FloatTF32Type.get()),
            (F64Type, F64Type.get()),
            (NoneType, NoneType.get()),
            (ComplexType, ComplexType.get(f32)),
            (VectorType, VectorType.get([2, 3], f32)),
            (RankedTensorType, RankedTensorType.get([2, 3], f32)),
            (UnrankedTensorType, UnrankedTensorType.get(f32)),
            (MemRefType, MemRefType.get([2, 3], f32)),
            (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
            (TupleType, TupleType.get_tuple([f32])),
            (FunctionType, FunctionType.get([], [])),
            (OpaqueType, OpaqueType.get("tensor", "bob")),
        ]

        # CHECK: IntegerType(i16)
        # CHECK: IndexType(index)
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        # CHECK: Float8E3M4Type(f8E3M4)
        # CHECK: Float8E4M3Type(f8E4M3)
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        # CHECK: Float8E5M2Type(f8E5M2)
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        # CHECK: BF16Type(bf16)
        # CHECK: F16Type(f16)
        # CHECK: F32Type(f32)
        # CHECK: FloatTF32Type(tf32)
        # CHECK: F64Type(f64)
        # CHECK: NoneType(none)
        # CHECK: ComplexType(complex<f32>)
        # CHECK: VectorType(vector<2x3xf32>)
        # CHECK: RankedTensorType(tensor<2x3xf32>)
        # CHECK: UnrankedTensorType(tensor<*xf32>)
        # CHECK: MemRefType(memref<2x3xf32>)
        # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
        # CHECK: TupleType(tuple<f32>)
        # CHECK: FunctionType(() -> ())
        # CHECK: OpaqueType(!tensor.bob)
        for _, instance in types:
            print(repr(instance))

        # Test getTypeIdFunction agrees with
        # mlirTypeGetTypeID(self) for an instance.
        # CHECK: all equal
        for cls, instance in types:
            tid1, tid2 = cls.static_typeid, Type(instance).typeid
            assert tid1 == tid2 and hash(tid1) == hash(
                tid2
            ), f"expected hash and value equality {cls} {instance}"
        # Any mismatch raised above, so reaching here means every pair agreed.
        print("all equal")

        # Test that PyTypeID values work as dict keys: each concrete class has
        # a distinct static typeid, and looking up an instance's runtime typeid
        # must find the entry stored under its class's static typeid.
        typeid_dict = {cls.static_typeid: instance for cls, instance in types}
        assert len(typeid_dict) == len(types), "expected distinct static typeids"

        # CHECK: all equal
        for cls, instance in types:
            assert (
                typeid_dict[instance.typeid] is instance
            ), f"expected dict lookup by typeid to find {cls} {instance}"
        print("all equal")

        # CHECK: ShapedType has no typeid.
        try:
            print(ShapedType.static_typeid)
        except AttributeError as e:
            print(e)
        else:
            print("Exception not produced")

        vector_type = Type.parse("vector<2x3xf32>")
        # CHECK: True
        print(ShapedType(vector_type).typeid == vector_type.typeid)


# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():
    with Context() as ctx, Location.unknown():
        # Needed for the unregistered "test.make_tuple" op parsed below.
        ctx.allow_unregistered_dialects = True

        def print_downcasted(typ):
            """Print the class name and repr of `typ` after maybe_downcast()."""
            downcasted = Type(typ).maybe_downcast()
            print(type(downcasted).__name__)
            print(repr(downcasted))

        # CHECK: F16Type
        # CHECK: F16Type(f16)
        print_downcasted(F16Type.get())
        # CHECK: F32Type
        # CHECK: F32Type(f32)
        print_downcasted(F32Type.get())
        # CHECK: FloatTF32Type
        # CHECK: FloatTF32Type(tf32)
        print_downcasted(FloatTF32Type.get())
        # CHECK: F64Type
        # CHECK: F64Type(f64)
        print_downcasted(F64Type.get())
        # CHECK: Float4E2M1FNType
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        print_downcasted(Float4E2M1FNType.get())
        # CHECK: Float6E2M3FNType
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        print_downcasted(Float6E2M3FNType.get())
        # CHECK: Float6E3M2FNType
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        print_downcasted(Float6E3M2FNType.get())
        # CHECK: Float8E3M4Type
        # CHECK: Float8E3M4Type(f8E3M4)
        print_downcasted(Float8E3M4Type.get())
        # CHECK: Float8E4M3B11FNUZType
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        print_downcasted(Float8E4M3B11FNUZType.get())
        # CHECK: Float8E4M3Type
        # CHECK: Float8E4M3Type(f8E4M3)
        print_downcasted(Float8E4M3Type.get())
        # CHECK: Float8E4M3FNType
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        print_downcasted(Float8E4M3FNType.get())
        # CHECK: Float8E4M3FNUZType
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        print_downcasted(Float8E4M3FNUZType.get())
        # CHECK: Float8E5M2Type
        # CHECK: Float8E5M2Type(f8E5M2)
        print_downcasted(Float8E5M2Type.get())
        # CHECK: Float8E5M2FNUZType
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        print_downcasted(Float8E5M2FNUZType.get())
        # CHECK: Float8E8M0FNUType
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        print_downcasted(Float8E8M0FNUType.get())
        # CHECK: BF16Type
        # CHECK: BF16Type(bf16)
        print_downcasted(BF16Type.get())
        # CHECK: IndexType
        # CHECK: IndexType(index)
        print_downcasted(IndexType.get())
        # CHECK: IntegerType
        # CHECK: IntegerType(i32)
        print_downcasted(IntegerType.get_signless(32))

        f32 = F32Type.get()
        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
        # CHECK: RankedTensorType
        print(type(ranked_tensor.type).__name__)
        # CHECK: RankedTensorType(tensor<10x10xf32>)
        print(repr(ranked_tensor.type))

        cf32 = ComplexType.get(f32)
        # CHECK: ComplexType
        print(type(cf32).__name__)
        # CHECK: ComplexType(complex<f32>)
        print(repr(cf32))

        vector = VectorType.get([10, 10], f32)
        tuple_type = TupleType.get_tuple([f32, vector])
        # CHECK: TupleType
        print(type(tuple_type).__name__)
        # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
        print(repr(tuple_type))
        # CHECK: F32Type(f32)
        print(repr(tuple_type.get_type(0)))
        # CHECK: VectorType(vector<10x10xf32>)
        print(repr(tuple_type.get_type(1)))

        index_type = IndexType.get()

        @func.FuncOp.from_py_func()
        def default_builder():
            c0 = arith.ConstantOp(f32, 0.0)
            unranked_tensor_type = UnrankedTensorType.get(f32)
            unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
            # CHECK: UnrankedTensorType
            print(type(unranked_tensor.type).__name__)
            # CHECK: UnrankedTensorType(tensor<*xf32>)
            print(repr(unranked_tensor.type))

            c10 = arith.ConstantOp(index_type, 10)
            memref_f32_t = MemRefType.get([10, 10], f32)
            memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
            # CHECK: MemRefType
            print(type(memref_f32.type).__name__)
            # CHECK: MemRefType(memref<10x10xf32>)
            print(repr(memref_f32.type))

            unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
            unranked_memref_f32 = memref.AllocOp(
                unranked_memref_t, [c10, c10], []
            ).result
            # CHECK: UnrankedMemRefType
            print(type(unranked_memref_f32.type).__name__)
            # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
            print(repr(unranked_memref_f32.type))

            tuple_value = Operation.parse(
                '"test.make_tuple"() : () -> tuple<i32, f32>'
            ).result
            # CHECK: TupleType
            print(type(tuple_value.type).__name__)
            # CHECK: TupleType(tuple<i32, f32>)
            print(repr(tuple_value.type))

            return c0, c10


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
# This tests being able to materialize a type from a dialect *and* have
# the implemented type caster called without explicitly importing the dialect.
# I.e., we get a transform.OperationType without explicitly importing the transform dialect.
@run
def testCustomTypeTypeCaster():
    with Context() as ctx, Location.unknown():
        # Parse in the active context instead of a second, unrelated Context().
        t = Type.parse('!transform.op<"foo.bar">', ctx)
        # CHECK: !transform.op<"foo.bar">
        print(t)
        # CHECK: OperationType(!transform.op<"foo.bar">)
        print(repr(t))


# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
    def stride(strides, offset=0):
        """Build a StridedLayoutAttr with the given strides and offset."""
        return StridedLayoutAttr.get(offset, strides)

    with Context(), Location.unknown():
        ia = T.i(5)
        sia = T.si(6)
        uia = T.ui(7)
        assert repr(ia) == "IntegerType(i5)"
        assert repr(sia) == "IntegerType(si6)"
        assert repr(uia) == "IntegerType(ui7)"

        assert T.i(16) == T.i16()
        assert T.si(16) == T.si16()
        assert T.ui(16) == T.ui16()

        c1 = T.complex(T.f16())
        c2 = T.complex(T.i32())
        assert repr(c1) == "ComplexType(complex<f16>)"
        assert repr(c2) == "ComplexType(complex<i32>)"

        vec_1 = T.vector(2, 3, T.f32())
        vec_2 = T.vector(2, 3, 4, T.f32())
        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"

        m1 = T.memref(2, 3, 4, T.f64())
        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"

        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"

        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"

        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
        assert (
            repr(m4)
            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
        )

        S = ShapedType.get_dynamic_size()

        t1 = T.tensor(S, 3, S, T.f64())
        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut1 = T.tensor(T.f64())
        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
        t2 = T.tensor(S, 3, S, element_type=T.f64())
        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut2 = T.tensor(element_type=T.f64())
        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"

        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'

        v = T.vector(3, 3, 3, T.f64())
        assert repr(v) == "VectorType(vector<3x3x3xf64>)"

        # Ranked/unranked memrefs via positional and keyword element types.
        m5 = T.memref(S, 3, S, T.f64())
        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
        um1 = T.memref(T.f64())
        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
        m6 = T.memref(S, 3, S, element_type=T.f64())
        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
        um2 = T.memref(element_type=T.f64())
        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"

        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"

        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        opaq = T.opaque("scf", "placeholder")
        assert repr(opaq) == "OpaqueType(!scf.placeholder)"

        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"

        # Named fn_type rather than `func` so the imported `func` dialect
        # module is not shadowed inside this function.
        fn_type = T.function(
            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
        )
        assert repr(fn_type) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"