xref: /llvm-project/mlir/test/python/ir/array_attributes.py (revision 71a254543d44a943dfe8790abc60795b87173f0b)
1# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it when numpy is not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8
9
def run(f):
    """Decorator that executes a test function immediately at import time.

    Prints a "TEST: <name>" banner (the line each test's label directive
    anchors on), calls ``f``, then forces a garbage collection and asserts
    that ``Context._get_live_count()`` is zero, i.e. the test left no MLIR
    context alive. Returns ``f`` unchanged so the decorated name stays bound
    to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so objects kept alive only by reference cycles are freed
    # before the live contexts are counted.
    gc.collect()
    live_count = Context._get_live_count()
    # A bare assert reports nothing useful; name the test and the count.
    assert live_count == 0, f"{f.__name__} left {live_count} live Context(s)"
    return f
16
17
18################################################################################
19# Tests of the array/buffer .get() factory method on unsupported dtype.
20################################################################################
21
22
# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
    """Building a dense attribute from a NumPy string array must raise ValueError."""
    with Context():
        array = np.array([["hello", "goodbye"]])
        try:
            DenseElementsAttr.get(array)
        except ValueError as e:
            # CHECK: unimplemented array format conversion from format:
            print(e)
        else:
            # Make a missing exception explicit in the output rather than
            # relying only on the expected message being absent.
            print("FAILED: expected ValueError for unsupported dtype")
32
# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
    """A dtype without a buffer format is accepted when a type is supplied.

    datetime64 specifically isn't important: it's just a 64-bit type that
    doesn't have a format under the Python buffer protocol. A more realistic
    example would be a NumPy extension type like the bfloat16 type from the
    ml_dtypes package, which isn't a dependency of this test.
    """
    with Context():
        i64 = IntegerType.get_signless(64)
        raw = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        # Same 64-bit payload, reinterpreted under a format-less dtype.
        opaque = raw.view(np.datetime64)
        attr = DenseElementsAttr.get(opaque, type=i64)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))
50
51
52################################################################################
53# Splats.
54################################################################################
55
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
    """get_splat() on an i32 tensor type yields a splat whose value round-trips."""
    with Context(), Location.unknown():
        i32 = IntegerType.get_signless(32)
        fill = IntegerAttr.get(i32, 555)
        splat = DenseElementsAttr.get_splat(RankedTensorType.get((2, 3, 4), i32), fill)
        # CHECK: dense<555> : tensor<2x3x4xi32>
        print(splat)
        # CHECK: is_splat: True
        print("is_splat:", splat.is_splat)

        # CHECK: splat_value: IntegerAttr(555 : i32)
        value = splat.get_splat_value()
        print("splat_value:", repr(value))
        assert value == fill
73
74
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
    """get_splat() with an f32 element; the stored splat value equals the input."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        fill = FloatAttr.get(f32, 1.2)
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        splat = DenseElementsAttr.get_splat(tensor_type, fill)
        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
        print(splat)
        assert splat.get_splat_value() == fill
86
87
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
    """get_splat() rejects non-shaped/unranked types and element type mismatches.

    Each rejected call prints its ValueError message; the expected messages
    are listed next to the corresponding case, in execution order.
    """
    with Context(), Location.unknown():
        f32 = F32Type.get()
        f64 = F64Type.get()
        f32_value = FloatAttr.get(f32, 1.2)
        f64_value = FloatAttr.get(f64, 1.2)
        static_tensor = RankedTensorType.get((2, 3, 4), f32)
        unranked_tensor = UnrankedTensorType.get(f32)

        bad_calls = [
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
            (f32, f32_value),
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
            (unranked_tensor, f32_value),
            # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
            (static_tensor, f64_value),
        ]
        for shaped, element in bad_calls:
            try:
                DenseElementsAttr.get_splat(shaped, element)
            except ValueError as e:
                print(e)
117
118
# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
    """An array whose elements are all equal is reported as a splat attribute."""
    with Context():
        ones = np.full((2, 3), 1.0, dtype=np.float32)
        splat = DenseElementsAttr.get(ones)
        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
        print(splat)
        # CHECK: is_splat: True
        print("is_splat:", splat.is_splat)
        # CHECK{LITERAL}: [[1. 1. 1.]
        # CHECK{LITERAL}:  [1. 1. 1.]]
        print(np.array(splat))
132
133
# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
    """Distinct element values must not be reported as a splat."""
    with Context():
        mixed = DenseElementsAttr.get(np.array([2.0, 1.0, 1.0], dtype=np.float32))
        # CHECK: is_splat: False
        print("is_splat:", mixed.is_splat)
142
143
144################################################################################
145# Tests of the array/buffer .get() factory method, in all of its permutations.
146################################################################################
147
148### explicitly provided types
149
150
# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
    """Raw uint16 words are bit-cast to bf16 when an explicit type is given."""
    with Context():
        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, type=BF16Type.get())
        # Note: These values don't mean much since just bit-casting. But they
        # shouldn't change.
        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
        print(attr)
160
161
# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
    """Byte-sized storage is bit-cast to i4 elements when an explicit type is given.

    The storage dtype is int8, not uint8. NumPy 2 raises OverflowError when a
    negative Python int is converted to an unsigned dtype; older versions
    wrapped it, with a DeprecationWarning since 1.24. int8 produces the same
    two's-complement bytes, so the resulting attribute is unchanged.
    """
    with Context():
        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
        # Note: These values don't mean much since just bit-casting. But they
        # shouldn't change.
        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
        print(attr)
171
172
# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
    """Little-endian bit-packed booleans load as i1 given an explicit shape."""
    with Context():
        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
        # axis=None flattens while packing, so the logical shape is passed
        # explicitly below.
        array = np.packbits(bool_array, axis=None, bitorder="little")
        attr = DenseElementsAttr.get(
            array, type=IntegerType.get_signless(1), shape=bool_array.shape
        )
        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
        print(attr)
183
184
# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
    """A single packed byte of all-0 or all-1 bits yields an i1 splat."""
    with Context():
        zero = np.array(0, dtype=np.uint8)
        # 255 sets every bit of the packed byte.
        one = np.array(255, dtype=np.uint8)
        # CHECK: dense<false> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
        # CHECK: dense<true> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
203
204
205### float and double arrays.
206
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
    """float16 arrays become f16 tensors and convert back to NumPy."""
    with Context():
        halves = np.array(
            [2.0, 4.0, 8.0, 16.0, 32.0, 64.0], dtype=np.float16
        ).reshape(2, 3)
        dense = DenseElementsAttr.get(halves)
        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
        print(dense)
        # CHECK: {{\[}}[ 2. 4. 8.]
        # CHECK: {{\[}}16. 32. 64.]]
        print(np.array(dense))
218
219
# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
    """float32 arrays become f32 tensors and convert back to NumPy."""
    with Context():
        singles = np.array(
            [1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype=np.float32
        ).reshape(2, 3)
        dense = DenseElementsAttr.get(singles)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
        print(dense)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(dense))
231
232
# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
    """float64 arrays become f64 tensors and convert back to NumPy."""
    with Context():
        doubles = np.array(
            [1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype=np.float64
        ).reshape(2, 3)
        dense = DenseElementsAttr.get(doubles)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
        print(dense)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(dense))
244
245
246### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
    """int16 data with the default (signless) mapping becomes an i16 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
258
259
# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
    """uint16 data with the default (signless) mapping also becomes i16."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
271
272
# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
    """With signless=False, int16 data becomes a signed si16 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
284
285
# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
    """With signless=False, uint16 data becomes an unsigned ui16 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
297
298
299### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
    """int32 data with the default (signless) mapping becomes an i32 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
311
312
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
    """uint32 data with the default (signless) mapping also becomes i32."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
324
325
# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
    """With signless=False, int32 data becomes a signed si32 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
337
338
# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
    """With signless=False, uint32 data becomes an unsigned ui32 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
350
351
### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
    """int64 data with the default (signless) mapping becomes an i64 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
364
365
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
    """uint64 data with the default (signless) mapping also becomes i64."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
        dense = DenseElementsAttr.get(values)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
377
378
# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
    """With signless=False, int64 data becomes a signed si64 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
390
391
# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
    """With signless=False, uint64 data becomes an unsigned ui64 tensor."""
    with Context():
        values = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
        dense = DenseElementsAttr.get(values, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
403
404
# CHECK-LABEL: TEST: testGetDenseElementsIndex
@run
def testGetDenseElementsIndex():
    """int64 data with an explicit index type converts back as an int64 array."""
    with Context(), Location.unknown():
        values = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(values, type=IndexType.get())
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
        print(dense)
        roundtrip = np.array(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(roundtrip)
        # CHECK: True
        print(roundtrip.dtype == np.int64)
420