xref: /llvm-project/mlir/test/python/ir/builtin_types.py (revision e17c91341be2f6a2d229ab44a4290e7d0ef2e094)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5from mlir.dialects import arith, tensor, func, memref
6import mlir.extras.types as T
7
8
def run(f):
    """Execute the test function ``f`` immediately and hand it back unchanged.

    Applied as a decorator, so every test body runs at import time. A banner
    naming the function is printed first so output can be matched per test.
    After the body returns, a garbage collection pass is forced and the number
    of live ``Context`` objects must be zero, i.e. the test must not leak any
    context.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so contexts that are only reachable through dropped locals
    # or reference cycles are released before they are counted.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
15
16
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
    """Parse ``i32`` and print it after dropping our own context handle."""
    context = Context()
    parsed = Type.parse("i32", context)
    assert parsed.context is context
    # Release the local reference to the context and collect; the parsed type
    # is still expected to be printable afterwards.
    context = None
    gc.collect()
    # CHECK: i32
    print(str(parsed))
    # CHECK: Type(i32)
    print(repr(parsed))
29
30
# CHECK-LABEL: TEST: testParseError
@run
def testParseError():
    """Parsing an unknown type name raises ``MLIRError`` carrying diagnostics."""
    context = Context()
    try:
        Type.parse("BAD_TYPE_DOES_NOT_EXIST", context)
    except MLIRError as err:
        # CHECK: testParseError: <
        # CHECK:   Unable to parse type:
        # CHECK:   error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
        # CHECK: >
        print(f"testParseError: <{err}>")
    else:
        print("Exception not produced")
45
46
# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
    """Two parses of the same type text compare equal; different types do not."""
    context = Context()
    first_i32 = Type.parse("i32", context)
    f32 = Type.parse("f32", context)
    second_i32 = Type.parse("i32", context)
    # CHECK: t1 == t1: True
    print("t1 == t1:", first_i32 == first_i32)
    # CHECK: t1 == t2: False
    print("t1 == t2:", first_i32 == f32)
    # CHECK: t1 == t3: True
    print("t1 == t3:", first_i32 == second_i32)
    # CHECK: t1 is None: False
    print("t1 is None:", first_i32 is None)
62
63
# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
    """Equal types hash equally, so a set collapses duplicates."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    t3 = Type.parse("i32", ctx)

    # CHECK: hash(t1) == hash(t3): True
    # Use the builtin hash() so the printed label matches what is computed and
    # the same hashing path used by set/dict lookups is exercised.
    print("hash(t1) == hash(t3):", hash(t1) == hash(t3))

    # i32 is inserted twice, f32 once: two distinct elements expected.
    s = {t1, t2, t3}
    # CHECK: len(s): 2
    print("len(s):", len(s))
81
82
# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
    """Constructing ``Type`` from an existing type yields an equal type."""
    context = Context()
    original = Type.parse("i32", context)
    copied = Type(original)
    # CHECK: t1 == t2: True
    print("t1 == t2:", original == copied)
91
92
# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
    """Exercise the static ``isinstance`` predicates of concrete type classes."""
    context = Context()
    i32 = Type.parse("i32", context)
    f32 = Type.parse("f32", context)
    # One line per (class, candidate) pair below, in the same order.
    # CHECK: True
    # CHECK: False
    # CHECK: False
    # CHECK: True
    # CHECK: True
    probes = (
        (IntegerType, i32),
        (F32Type, i32),
        (FloatType, i32),
        (F32Type, f32),
        (FloatType, f32),
    )
    for type_class, candidate in probes:
        print(type_class.isinstance(candidate))
109
110
# CHECK-LABEL: TEST: testFloatTypeSubclasses
@run
def testFloatTypeSubclasses():
    """Every builtin floating-point type parses to a ``FloatType`` instance."""
    ctx = Context()
    float_type_names = (
        "f4E2M1FN",
        "f6E2M3FN",
        "f6E3M2FN",
        "f8E3M4",
        "f8E4M3",
        "f8E4M3FN",
        "f8E5M2",
        "f8E4M3FNUZ",
        "f8E4M3B11FNUZ",
        "f8E5M2FNUZ",
        "f8E8M0FNU",
        "f16",
        "bf16",
        "f32",
        "tf32",
        "f64",
    )
    # One "True" line is expected per name above, in the same order.
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: True
    for name in float_type_names:
        print(isinstance(Type.parse(name, ctx), FloatType))
147
148
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
    """Comparing a type against a non-type returns False instead of raising."""
    context = Context()
    i32 = Type.parse("i32", context)
    # CHECK: False
    print(i32 == "foo")
    # CHECK: False
    print(i32 is None)
    # CHECK: True
    print(i32 is not None)
161
162
# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
    """Round-trip a type through its C API capsule."""
    with Context() as ctx:
        original = Type.parse("i32", ctx)
    capsule = original._CAPIPtr
    # CHECK: mlir.ir.Type._CAPIPtr
    print(capsule)
    rebuilt = Type._CAPICreate(capsule)
    assert rebuilt == original
    # The rebuilt type must report the very same Python context object.
    assert rebuilt.context is ctx
174
175
# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
    """Casting to IntegerType works for i32 (repeatedly) and rejects f32."""
    context = Context()
    as_integer = IntegerType(Type.parse("i32", context))
    # Re-casting an IntegerType to IntegerType must also be accepted.
    IntegerType(as_integer)
    # CHECK: Type(i32)
    print(repr(as_integer))
    try:
        IntegerType(Type.parse("f32", context))
    except ValueError as err:
        # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
        print("ValueError:", err)
    else:
        print("Exception not produced")
192
193
# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
    """Width and signedness queries, plus the signless/signed/unsigned getters."""
    with Context():
        signless = IntegerType(Type.parse("i32"))
        # CHECK: i32 width: 32
        print("i32 width:", signless.width)
        # CHECK: i32 signless: True
        print("i32 signless:", signless.is_signless)
        # CHECK: i32 signed: False
        print("i32 signed:", signless.is_signed)
        # CHECK: i32 unsigned: False
        print("i32 unsigned:", signless.is_unsigned)

        signed = IntegerType(Type.parse("si32"))
        # CHECK: s32 signless: False
        print("s32 signless:", signed.is_signless)
        # CHECK: s32 signed: True
        print("s32 signed:", signed.is_signed)
        # CHECK: s32 unsigned: False
        print("s32 unsigned:", signed.is_unsigned)

        unsigned = IntegerType(Type.parse("ui32"))
        # CHECK: u32 signless: False
        print("u32 signless:", unsigned.is_signless)
        # CHECK: u32 signed: False
        print("u32 signed:", unsigned.is_signed)
        # CHECK: u32 unsigned: True
        print("u32 unsigned:", unsigned.is_unsigned)

        # Construct one type through each getter, then print them in order.
        i16 = IntegerType.get_signless(16)
        si8 = IntegerType.get_signed(8)
        ui64 = IntegerType.get_unsigned(64)
        # CHECK: signless: i16
        print("signless:", i16)
        # CHECK: signed: si8
        print("signed:", si8)
        # CHECK: unsigned: ui64
        print("unsigned:", ui64)
230
231
# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
    """``IndexType.get`` works with the implicitly active context."""
    with Context():
        index = IndexType.get()
        # CHECK: index type: index
        print("index type:", index)
238
239
# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
    """Each builtin float class's ``get()`` prints its textual spelling."""
    with Context():
        float_classes = (
            Float4E2M1FNType,
            Float6E2M3FNType,
            Float6E3M2FNType,
            Float8E3M4Type,
            Float8E4M3Type,
            Float8E4M3FNType,
            Float8E5M2Type,
            Float8E5M2FNUZType,
            Float8E4M3FNUZType,
            Float8E4M3B11FNUZType,
            Float8E8M0FNUType,
            BF16Type,
            F16Type,
            FloatTF32Type,
            F32Type,
        )
        # One line per class above, in the same order.
        # CHECK: float: f4E2M1FN
        # CHECK: float: f6E2M3FN
        # CHECK: float: f6E3M2FN
        # CHECK: float: f8E3M4
        # CHECK: float: f8E4M3
        # CHECK: float: f8E4M3FN
        # CHECK: float: f8E5M2
        # CHECK: float: f8E5M2FNUZ
        # CHECK: float: f8E4M3FNUZ
        # CHECK: float: f8E4M3B11FNUZ
        # CHECK: float: f8E8M0FNU
        # CHECK: float: bf16
        # CHECK: float: f16
        # CHECK: float: tf32
        # CHECK: float: f32
        for float_class in float_classes:
            print("float:", float_class.get())

        # f64 is kept separate because its width is queried as well.
        # CHECK: float: f64
        double = F64Type.get()
        print("float:", double)
        # CHECK: f64 width: 64
        print("f64 width:", double.width)
279
280
# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
    """``NoneType.get`` works with the implicitly active context."""
    with Context():
        none_type = NoneType.get()
        # CHECK: none type: none
        print("none type:", none_type)
287
288
# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
    """ComplexType element access, construction and element-type validation."""
    with Context():
        parsed = ComplexType(Type.parse("complex<i32>"))
        # CHECK: complex type element: i32
        print("complex type element:", parsed.element_type)

        # CHECK: complex type: complex<f32>
        print("complex type:", ComplexType.get(F32Type.get()))

        # index is neither a float nor an integer type, so it is rejected.
        index = IndexType.get()
        try:
            ComplexType.get(index)
        except ValueError as err:
            # CHECK: invalid 'Type(index)' and expected floating point or integer type.
            print(err)
        else:
            print("Exception not produced")
309
310
# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not itself a kind of builtin type: it is the common base class
# of vectors, memrefs and tensors. This test therefore exercises the
# shaped-type API through a vector instance; the class hierarchy is preserved
# on the Python side.
@run
def testConcreteShapedType():
    with Context():
        vec = VectorType(Type.parse("vector<2x3xf32>"))
        # One line per (label, value) pair below, in the same order.
        # CHECK: element type: f32
        # CHECK: whether the given shaped type is ranked: True
        # CHECK: rank: 2
        # CHECK: whether the shaped type has a static shape: True
        # CHECK: whether the dim-th dimension is dynamic: False
        # CHECK: dim size: 3
        # CHECK: is_dynamic_size: False
        # CHECK: is_dynamic_stride_or_offset: False
        # CHECK: isinstance(ShapedType): True
        observations = (
            ("element type:", vec.element_type),
            ("whether the given shaped type is ranked:", vec.has_rank),
            ("rank:", vec.rank),
            ("whether the shaped type has a static shape:", vec.has_static_shape),
            ("whether the dim-th dimension is dynamic:", vec.is_dynamic_dim(0)),
            ("dim size:", vec.get_dim_size(1)),
            ("is_dynamic_size:", vec.is_dynamic_size(3)),
            ("is_dynamic_stride_or_offset:", vec.is_dynamic_stride_or_offset(1)),
            ("isinstance(ShapedType):", isinstance(vec, ShapedType)),
        )
        for label, value in observations:
            print(label, value)
337
338
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType works as an abstract base-class view over a concrete
# shaped type (a vector here).
@run
def testAbstractShapedType():
    context = Context()
    shaped = ShapedType(Type.parse("vector<2x3xf32>", context))
    # CHECK: element type: f32
    print("element type:", shaped.element_type)
348
349
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
    """VectorType construction, element validation and scalable dimensions."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        # CHECK: vector type: vector<2x3xf32>
        print("vector type:", VectorType.get(shape, f32))

        none = NoneType.get()
        try:
            VectorType.get(shape, none)
        except MLIRError as err:
            # CHECK: Invalid type:
            # CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point
            print(err)
        else:
            print("Exception not produced")

        # Scalable vectors described by a per-dimension boolean mask...
        by_mask_2d = VectorType.get(shape, f32, scalable=[False, True])
        by_mask_3d = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
        assert by_mask_2d.scalable
        assert by_mask_3d.scalable
        assert by_mask_2d.scalable_dims == [False, True]
        assert by_mask_3d.scalable_dims == [True, False, True]
        # CHECK: scalable 1: vector<2x[3]xf32>
        print("scalable 1: ", by_mask_2d)
        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
        print("scalable 2: ", by_mask_3d)

        # ...and, equivalently, by the list of scalable dimension indices.
        assert VectorType.get(shape, f32, scalable_dims=[1]) == by_mask_2d
        assert VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2]) == by_mask_3d

        # Malformed scalable arguments are rejected with ValueError; one line
        # per keyword set below, in the same order.
        # CHECK: Expected len(scalable) == len(shape).
        # CHECK: kwargs are mutually exclusive.
        # CHECK: Scalable dimension index out of bounds.
        invalid_kwargs = (
            {"scalable": [False, True, True]},
            {"scalable": [False, True], "scalable_dims": [1]},
            {"scalable_dims": [42]},
        )
        for kwargs in invalid_kwargs:
            try:
                VectorType.get(shape, f32, **kwargs)
            except ValueError as err:
                print(err)
            else:
                print("Exception not produced")
408
409
# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
    """RankedTensorType construction, element-type validation and encoding."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        # CHECK: ranked tensor type: tensor<2x3xf32>
        print("ranked tensor type:", RankedTensorType.get(shape, f32))

        none = NoneType.get()
        try:
            RankedTensorType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(e)
        else:
            print("Exception not produced")

        # Named `encoded` rather than `tensor` so it does not shadow the
        # `tensor` dialect module imported at file scope.
        encoded = RankedTensorType.get(shape, f32, StringAttr.get("encoding"))
        assert encoded.shape == shape
        assert encoded.encoding.value == "encoding"

        # Without an explicit encoding argument, the encoding is None.
        assert RankedTensorType.get(shape, f32).encoding is None
436
437
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
    """Unranked tensors reject rank-dependent queries and invalid elements."""
    with Context(), Location.unknown():
        element = F32Type.get()
        unranked = UnrankedTensorType.get(element)
        # CHECK: unranked tensor type: tensor<*xf32>
        print("unranked tensor type:", unranked)

        # Each query below needs a rank and must raise ValueError.
        # CHECK: calling this method requires that the type has a rank.
        # CHECK: calling this method requires that the type has a rank.
        # CHECK: calling this method requires that the type has a rank.
        rank_dependent_queries = (
            lambda: unranked.rank,
            lambda: unranked.is_dynamic_dim(0),
            lambda: unranked.get_dim_size(1),
        )
        for query in rank_dependent_queries:
            try:
                query()
            except ValueError as err:
                print(err)
            else:
                print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedTensorType.get(none)
        except MLIRError as err:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(err)
        else:
            print("Exception not produced")
478
479
# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
    """MemRefType with a memory space, with an explicit layout, and invalid."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]

        # Memory space 2 with the default layout.
        in_space_2 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
        # CHECK: memref type: memref<2x3xf32, 2>
        print("memref type:", in_space_2)
        # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>)
        print("memref layout:", repr(in_space_2.layout))
        # CHECK: memref affine map: (d0, d1) -> (d0, d1)
        print("memref affine map:", in_space_2.affine_map)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(in_space_2.memory_space))

        # Transposing permutation layout and no memory space.
        transposed = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
        permuted = MemRefType.get(shape, f32, layout=transposed)
        # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
        print("memref type:", permuted)
        # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
        print("memref layout:", permuted.layout)
        # CHECK: memref affine map: (d0, d1) -> (d1, d0)
        print("memref affine map:", permuted.affine_map)
        # CHECK: memory space: None
        print("memory space:", permuted.memory_space)

        none = NoneType.get()
        try:
            MemRefType.get(shape, none)
        except MLIRError as err:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(err)
        else:
            print("Exception not produced")

    # The type is still queried after the context manager has exited.
    assert in_space_2.shape == shape
519
520
# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
    """Unranked memrefs: memory space, rank-dependent queries, bad element."""
    with Context(), Location.unknown():
        element = F32Type.get()
        unranked = UnrankedMemRefType.get(element, Attribute.parse("2"))
        # CHECK: unranked memref type: memref<*xf32, 2>
        print("unranked memref type:", unranked)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(unranked.memory_space))

        # Each query below needs a rank and must raise ValueError.
        # CHECK: calling this method requires that the type has a rank.
        # CHECK: calling this method requires that the type has a rank.
        # CHECK: calling this method requires that the type has a rank.
        rank_dependent_queries = (
            lambda: unranked.rank,
            lambda: unranked.is_dynamic_dim(0),
            lambda: unranked.get_dim_size(1),
        )
        for query in rank_dependent_queries:
            try:
                query()
            except ValueError as err:
                print(err)
            else:
                print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedMemRefType.get(none, Attribute.parse("2"))
        except MLIRError as err:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(err)
        else:
            print("Exception not produced")
563
564
# CHECK-LABEL: TEST: testTupleType
@run
def testTupleType():
    """Build a tuple type from three member types and query it."""
    with Context():
        members = [
            IntegerType(Type.parse("i32")),
            F32Type.get(),
            VectorType(Type.parse("vector<2x3xf32>")),
        ]
        packed = TupleType.get_tuple(members)
        # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
        print("tuple type:", packed)
        # CHECK: number of types: 3
        print("number of types:", packed.num_types)
        # CHECK: pos-th type in the tuple type: f32
        print("pos-th type in the tuple type:", packed.get_type(1))
580
581
# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
    """FunctionType exposes its input and result type lists."""
    with Context():
        input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
        result_types = [IndexType.get()]
        # Named `fn_type` so it does not shadow the `func` dialect module
        # imported at file scope.
        fn_type = FunctionType.get(input_types, result_types)
        # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
        print("INPUTS:", fn_type.inputs)
        # CHECK: RESULTS: [IndexType(index)]
        print("RESULTS:", fn_type.results)
593
594
# CHECK-LABEL: TEST: testOpaqueType
@run
def testOpaqueType():
    """OpaqueType keeps its dialect namespace and raw data string."""
    with Context() as ctx:
        # "dialect" is not a registered dialect, so allow unregistered ones.
        ctx.allow_unregistered_dialects = True
        opaque_type = OpaqueType.get("dialect", "type")
        # CHECK: opaque type: !dialect.type
        print("opaque type:", opaque_type)
        # CHECK: dialect namespace: dialect
        print("dialect namespace:", opaque_type.dialect_namespace)
        # CHECK: data: type
        print("data:", opaque_type.data)
607
608
# CHECK-LABEL: TEST: testShapedTypeConstants
# Tests that ShapedType exposes its magic value constants as Python ints.
@run
def testShapedTypeConstants():
    sentinels = (
        ShapedType.get_dynamic_size(),
        ShapedType.get_dynamic_stride_or_offset(),
    )
    # CHECK: <class 'int'>
    # CHECK: <class 'int'>
    for sentinel in sentinels:
        print(type(sentinel))
617
618
# CHECK-LABEL: TEST: testTypeIDs
@run
def testTypeIDs():
    """Static and per-instance TypeIDs agree and work as dict keys."""
    with Context(), Location.unknown():
        f32 = F32Type.get()

        # (concrete Python type class, an instance of that class) pairs.
        types = [
            (IntegerType, IntegerType.get_signless(16)),
            (IndexType, IndexType.get()),
            (Float4E2M1FNType, Float4E2M1FNType.get()),
            (Float6E2M3FNType, Float6E2M3FNType.get()),
            (Float6E3M2FNType, Float6E3M2FNType.get()),
            (Float8E3M4Type, Float8E3M4Type.get()),
            (Float8E4M3Type, Float8E4M3Type.get()),
            (Float8E4M3FNType, Float8E4M3FNType.get()),
            (Float8E5M2Type, Float8E5M2Type.get()),
            (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
            (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
            (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
            (Float8E8M0FNUType, Float8E8M0FNUType.get()),
            (BF16Type, BF16Type.get()),
            (F16Type, F16Type.get()),
            (F32Type, F32Type.get()),
            (FloatTF32Type, FloatTF32Type.get()),
            (F64Type, F64Type.get()),
            (NoneType, NoneType.get()),
            (ComplexType, ComplexType.get(f32)),
            (VectorType, VectorType.get([2, 3], f32)),
            (RankedTensorType, RankedTensorType.get([2, 3], f32)),
            (UnrankedTensorType, UnrankedTensorType.get(f32)),
            (MemRefType, MemRefType.get([2, 3], f32)),
            (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
            (TupleType, TupleType.get_tuple([f32])),
            (FunctionType, FunctionType.get([], [])),
            (OpaqueType, OpaqueType.get("tensor", "bob")),
        ]

        # CHECK: IntegerType(i16)
        # CHECK: IndexType(index)
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        # CHECK: Float8E3M4Type(f8E3M4)
        # CHECK: Float8E4M3Type(f8E4M3)
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        # CHECK: Float8E5M2Type(f8E5M2)
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        # CHECK: BF16Type(bf16)
        # CHECK: F16Type(f16)
        # CHECK: F32Type(f32)
        # CHECK: FloatTF32Type(tf32)
        # CHECK: F64Type(f64)
        # CHECK: NoneType(none)
        # CHECK: ComplexType(complex<f32>)
        # CHECK: VectorType(vector<2x3xf32>)
        # CHECK: RankedTensorType(tensor<2x3xf32>)
        # CHECK: UnrankedTensorType(tensor<*xf32>)
        # CHECK: MemRefType(memref<2x3xf32>)
        # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
        # CHECK: TupleType(tuple<f32>)
        # CHECK: FunctionType(() -> ())
        # CHECK: OpaqueType(!tensor.bob)
        for _, t in types:
            print(repr(t))

        # Test getTypeIdFunction agrees with
        # mlirTypeGetTypeID(self) for an instance.
        # CHECK: all equal
        for t1, t2 in types:
            tid1, tid2 = t1.static_typeid, Type(t2).typeid
            assert tid1 == tid2 and hash(tid1) == hash(
                tid2
            ), f"expected hash and value equality {t1} {t2}"
        # The loop never breaks, so a `for ... else` would always run; print
        # unconditionally instead (a failing assert aborts before this line).
        print("all equal")

        # Test that PyTypeID values work as Python dict keys: key the dict by
        # each instance's TypeID and look the entry up through the class's
        # static TypeID. (Keying by the Python classes, as `dict(types)` does,
        # would never hash a TypeID at all.)
        typeid_dict = {t2.typeid: t1 for t1, t2 in types}
        # Every class in `types` is a distinct type kind, so no key collides.
        assert len(typeid_dict) == len(types)

        # CHECK: all equal
        for t1, t2 in types:
            assert (
                typeid_dict[t1.static_typeid] is t1
            ), f"expected dict lookup by static typeid to find {t1} for {t2}"
        print("all equal")

        # CHECK: ShapedType has no typeid.
        try:
            print(ShapedType.static_typeid)
        except AttributeError as e:
            print(e)

        vector_type = Type.parse("vector<2x3xf32>")
        # CHECK: True
        print(ShapedType(vector_type).typeid == vector_type.typeid)
720
721
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():
    """Generic types and op result types come back as concrete subclasses."""
    with Context() as ctx, Location.unknown():
        # Allows parsing the "test.make_tuple" op below, whose dialect is
        # presumably not registered in this context.
        ctx.allow_unregistered_dialects = True

        def print_downcasted(typ):
            """Wrap ``typ`` in a generic Type, downcast it, print class and repr."""
            downcasted = Type(typ).maybe_downcast()
            print(type(downcasted).__name__)
            print(repr(downcasted))

        # CHECK: F16Type
        # CHECK: F16Type(f16)
        print_downcasted(F16Type.get())
        # CHECK: F32Type
        # CHECK: F32Type(f32)
        print_downcasted(F32Type.get())
        # CHECK: FloatTF32Type
        # CHECK: FloatTF32Type(tf32)
        print_downcasted(FloatTF32Type.get())
        # CHECK: F64Type
        # CHECK: F64Type(f64)
        print_downcasted(F64Type.get())
        # CHECK: Float4E2M1FNType
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        print_downcasted(Float4E2M1FNType.get())
        # CHECK: Float6E2M3FNType
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        print_downcasted(Float6E2M3FNType.get())
        # CHECK: Float6E3M2FNType
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        print_downcasted(Float6E3M2FNType.get())
        # CHECK: Float8E3M4Type
        # CHECK: Float8E3M4Type(f8E3M4)
        print_downcasted(Float8E3M4Type.get())
        # CHECK: Float8E4M3B11FNUZType
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        print_downcasted(Float8E4M3B11FNUZType.get())
        # CHECK: Float8E4M3Type
        # CHECK: Float8E4M3Type(f8E4M3)
        print_downcasted(Float8E4M3Type.get())
        # CHECK: Float8E4M3FNType
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        print_downcasted(Float8E4M3FNType.get())
        # CHECK: Float8E4M3FNUZType
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        print_downcasted(Float8E4M3FNUZType.get())
        # CHECK: Float8E5M2Type
        # CHECK: Float8E5M2Type(f8E5M2)
        print_downcasted(Float8E5M2Type.get())
        # CHECK: Float8E5M2FNUZType
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        print_downcasted(Float8E5M2FNUZType.get())
        # CHECK: Float8E8M0FNUType
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        print_downcasted(Float8E8M0FNUType.get())
        # CHECK: BF16Type
        # CHECK: BF16Type(bf16)
        print_downcasted(BF16Type.get())
        # CHECK: IndexType
        # CHECK: IndexType(index)
        print_downcasted(IndexType.get())
        # CHECK: IntegerType
        # CHECK: IntegerType(i32)
        print_downcasted(IntegerType.get_signless(32))

        f32 = F32Type.get()
        # Op result types are expected to already be concrete subclasses.
        # (This probe previously appeared twice, copy-pasted verbatim.)
        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
        # CHECK: RankedTensorType
        print(type(ranked_tensor.type).__name__)
        # CHECK: RankedTensorType(tensor<10x10xf32>)
        print(repr(ranked_tensor.type))

        cf32 = ComplexType.get(f32)
        # CHECK: ComplexType
        print(type(cf32).__name__)
        # CHECK: ComplexType(complex<f32>)
        print(repr(cf32))

        vector = VectorType.get([10, 10], f32)
        tuple_type = TupleType.get_tuple([f32, vector])
        # CHECK: TupleType
        print(type(tuple_type).__name__)
        # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
        print(repr(tuple_type))
        # CHECK: F32Type(f32)
        print(repr(tuple_type.get_type(0)))
        # CHECK: VectorType(vector<10x10xf32>)
        print(repr(tuple_type.get_type(1)))

        index_type = IndexType.get()

        @func.FuncOp.from_py_func()
        def default_builder():
            c0 = arith.ConstantOp(f32, 0.0)
            unranked_tensor_type = UnrankedTensorType.get(f32)
            unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
            # CHECK: UnrankedTensorType
            print(type(unranked_tensor.type).__name__)
            # CHECK: UnrankedTensorType(tensor<*xf32>)
            print(repr(unranked_tensor.type))

            c10 = arith.ConstantOp(index_type, 10)
            memref_f32_t = MemRefType.get([10, 10], f32)
            memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
            # CHECK: MemRefType
            print(type(memref_f32.type).__name__)
            # CHECK: MemRefType(memref<10x10xf32>)
            print(repr(memref_f32.type))

            unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
            memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
            # CHECK: UnrankedMemRefType
            print(type(memref_f32.type).__name__)
            # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
            print(repr(memref_f32.type))

            # Plain string: the original f-string had no placeholders.
            tuple_type = Operation.parse(
                '"test.make_tuple"() : () -> tuple<i32, f32>'
            ).result
            # CHECK: TupleType
            print(type(tuple_type.type).__name__)
            # CHECK: TupleType(tuple<i32, f32>)
            print(repr(tuple_type.type))

            return c0, c10
854
855
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
# This tests being able to materialize a type from a dialect *and* have the
# implemented type caster called without explicitly importing the dialect,
# i.e., we get a transform.OperationType without explicitly importing the
# transform dialect.
@run
def testCustomTypeTypeCaster():
    with Context() as ctx, Location.unknown():
        # Parse in the context that is already active instead of creating a
        # second, throwaway Context just for this call.
        t = Type.parse('!transform.op<"foo.bar">', ctx)
        # CHECK: !transform.op<"foo.bar">
        print(t)
        # CHECK: OperationType(!transform.op<"foo.bar">)
        print(repr(t))
868
869
870# CHECK-LABEL: TEST: testTypeWrappers
871@run
872def testTypeWrappers():
873    def stride(strides, offset=0):
874        return StridedLayoutAttr.get(offset, strides)
875
876    with Context(), Location.unknown():
877        ia = T.i(5)
878        sia = T.si(6)
879        uia = T.ui(7)
880        assert repr(ia) == "IntegerType(i5)"
881        assert repr(sia) == "IntegerType(si6)"
882        assert repr(uia) == "IntegerType(ui7)"
883
884        assert T.i(16) == T.i16()
885        assert T.si(16) == T.si16()
886        assert T.ui(16) == T.ui16()
887
888        c1 = T.complex(T.f16())
889        c2 = T.complex(T.i32())
890        assert repr(c1) == "ComplexType(complex<f16>)"
891        assert repr(c2) == "ComplexType(complex<i32>)"
892
893        vec_1 = T.vector(2, 3, T.f32())
894        vec_2 = T.vector(2, 3, 4, T.f32())
895        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
896        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
897
898        m1 = T.memref(2, 3, 4, T.f64())
899        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
900
901        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
902        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
903
904        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
905        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
906
907        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
908        assert (
909            repr(m4)
910            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
911        )
912
913        S = ShapedType.get_dynamic_size()
914
915        t1 = T.tensor(S, 3, S, T.f64())
916        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
917        ut1 = T.tensor(T.f64())
918        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
919        t2 = T.tensor(S, 3, S, element_type=T.f64())
920        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
921        ut2 = T.tensor(element_type=T.f64())
922        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
923
924        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
925        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
926
927        v = T.vector(3, 3, 3, T.f64())
928        assert repr(v) == "VectorType(vector<3x3x3xf64>)"
929
930        m5 = T.memref(S, 3, S, T.f64())
931        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
932        um1 = T.memref(T.f64())
933        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
934        m6 = T.memref(S, 3, S, element_type=T.f64())
935        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
936        um2 = T.memref(element_type=T.f64())
937        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
938
939        m7 = T.memref(S, 3, S, T.f64())
940        assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
941        um3 = T.memref(T.f64())
942        assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
943
944        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
945        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
946        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
947        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
948
949        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
950        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
951        assert scalable_3 == scalable_1
952        assert scalable_4 == scalable_2
953
954        opaq = T.opaque("scf", "placeholder")
955        assert repr(opaq) == "OpaqueType(!scf.placeholder)"
956
957        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
958        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
959        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
960        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
961
962        func = T.function(
963            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
964        )
965        assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
966