xref: /llvm-project/mlir/test/python/ir/array_attributes.py (revision 5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff)
1# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it when numpy is not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8
def run(f):
  """Run the test function `f` right away and verify it leaked no Context.

  Used as a decorator on every test below. It:
    * prints a "TEST: <name>" header that the FileCheck labels in this file
      anchor on;
    * calls `f`;
    * forces a garbage-collection pass;
    * asserts that no MLIR `Context` objects remain alive.
  `f` is returned unchanged, so the module-level name still refers to the
  function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Force a collection so unreachable (e.g. cyclic) garbage that still holds
  # a Context is freed before the live-count check.
  gc.collect()
  assert Context._get_live_count() == 0
  return f
15
16################################################################################
17# Tests of the array/buffer .get() factory method on unsupported dtype.
18################################################################################
19
# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
  """DenseElementsAttr.get must reject a numpy array of strings.

  The conversion is expected to raise ValueError, and its message is printed
  so FileCheck can match it. If the call unexpectedly succeeds, fail right
  here with a descriptive error. Otherwise the failure would only surface as
  a missing FileCheck match.
  """
  with Context():
    array = np.array([["hello", "goodbye"]])
    try:
      DenseElementsAttr.get(array)
    except ValueError as e:
      # CHECK: unimplemented array format conversion from format:
      print(e)
    else:
      raise AssertionError(
          "DenseElementsAttr.get accepted a string array; expected ValueError")
29
30################################################################################
31# Splats.
32################################################################################
33
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
  """Splat the i32 value 555 across a 2x3x4 tensor and report is_splat."""
  with Context(), Location.unknown():
    i32 = IntegerType.get_signless(32)
    tensor_type = RankedTensorType.get((2, 3, 4), i32)
    splat = DenseElementsAttr.get_splat(tensor_type, IntegerAttr.get(i32, 555))
    # CHECK: dense<555> : tensor<2x3x4xi32>
    print(splat)
    # CHECK: is_splat: True
    print("is_splat:", splat.is_splat)
46
47
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
  """Splat the f32 value 1.2 across a 2x3x4 tensor."""
  with Context(), Location.unknown():
    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    splat = DenseElementsAttr.get_splat(tensor_type, FloatAttr.get(f32, 1.2))
    # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
    print(splat)
58
59
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
  """DenseElementsAttr.get_splat must reject invalid shaped/element types.

  Each invalid call is expected to raise ValueError, whose message is printed
  for FileCheck. If a call unexpectedly succeeds, fail immediately with a
  descriptive AssertionError instead of only surfacing later as a missing
  FileCheck match.
  """
  with Context(), Location.unknown():
    t = F32Type.get()
    other_t = F64Type.get()
    element = FloatAttr.get(t, 1.2)
    other_element = FloatAttr.get(other_t, 1.2)
    shaped_type = RankedTensorType.get((2, 3, 4), t)
    dynamic_shaped_type = UnrankedTensorType.get(t)
    non_shaped_type = t

    # A plain scalar type is not a ShapedType at all.
    try:
      DenseElementsAttr.get_splat(non_shaped_type, element)
    except ValueError as e:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
      print(e)
    else:
      raise AssertionError("get_splat accepted a non-shaped type")

    # An unranked tensor type has no static shape.
    try:
      DenseElementsAttr.get_splat(dynamic_shaped_type, element)
    except ValueError as e:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
      print(e)
    else:
      raise AssertionError("get_splat accepted an unranked tensor type")

    # The splat element's type (f64) differs from the tensor's (f32).
    try:
      DenseElementsAttr.get_splat(shaped_type, other_element)
    except ValueError as e:
      # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
      print(e)
    else:
      raise AssertionError("get_splat accepted a mismatched element type")
89
90
# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
  """An array whose elements are all equal is stored as a splat attribute."""
  ones = np.ones((2, 3), dtype=np.float32)
  with Context():
    dense = DenseElementsAttr.get(ones)
    # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
    print(dense)
    # CHECK: is_splat: True
    print("is_splat:", dense.is_splat)
    # NOTE(review): the expected "()" presumably pins how a splat attribute is
    # currently exposed through the buffer protocol -- confirm before relying
    # on it.
    # CHECK: ()
    round_trip = np.array(dense)
    print(round_trip)
103
104
# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
  """An array with differing elements must not be reported as a splat."""
  mixed = np.array([2.0, 1.0, 1.0], dtype=np.float32)
  with Context():
    splat_flag = DenseElementsAttr.get(mixed).is_splat
    # CHECK: is_splat: False
    print("is_splat:", splat_flag)
113
114
115################################################################################
116# Tests of the array/buffer .get() factory method, in all of its permutations.
117################################################################################
118
119### explicitly provided types
120
# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
  """Build a bf16 DenseElementsAttr from raw uint16 bit patterns.

  With an explicit `type`, the numpy buffer's bits are reinterpreted rather
  than numerically converted. The printed floats are therefore whatever
  these bit patterns decode to as bf16.
  """
  with Context():
    array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
    attr = DenseElementsAttr.get(array, type=BF16Type.get())
    # Note: These values don't mean much since just bit-casting. But they
    # shouldn't change.
    # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
    print(attr)
130
# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
  """Build an i4 DenseElementsAttr from a uint8 byte buffer (explicit type)."""
  with Context():
    # Build the bytes as int8 and reinterpret them as uint8. This yields the
    # same buffer (e.g. -2 -> 0xFE) that `dtype=np.uint8` with negative
    # literals used to produce by wrap-around. NumPy deprecated that implicit
    # wrap-around in 1.24 and raises OverflowError for it in 2.0.
    array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8).view(np.uint8)
    attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
    # Note: These values don't mean much since just bit-casting. But they
    # shouldn't change.
    # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
    print(attr)
140
141
# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
  """Build an i1 DenseElementsAttr from bit-packed booleans.

  `np.packbits(..., axis=None, bitorder="little")` flattens the 2x3 flags and
  packs them LSB-first into a 1-D uint8 buffer. The logical shape is
  therefore passed separately via `shape`.
  """
  with Context():
    bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
    array = np.packbits(bool_array, axis=None, bitorder="little")
    attr = DenseElementsAttr.get(
        array, type=IntegerType.get_signless(1), shape=bool_array.shape)
    # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
    print(attr)
151
152
# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
  """A single byte with an explicit i1 type and shape yields a bool splat.

  Per the expected output, an all-zero byte gives a `false` splat and the
  byte 255 gives a `true` splat.
  """
  with Context():
    zero = np.array(0, dtype=np.uint8)
    one = np.array(255, dtype=np.uint8)
    i1 = IntegerType.get_signless(1)
    shape = (4, 2, 5)
    # CHECK: dense<false> : tensor<4x2x5xi1>
    print(DenseElementsAttr.get(zero, type=i1, shape=shape))
    # CHECK: dense<true> : tensor<4x2x5xi1>
    print(DenseElementsAttr.get(one, type=i1, shape=shape))
165
166
167### float and double arrays.
168
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
  """Round-trip a float16 numpy array through an f16 DenseElementsAttr."""
  halves = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
  with Context():
    dense = DenseElementsAttr.get(halves)
    # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
    print(dense)
    # CHECK: {{\[}}[ 2. 4. 8.]
    # CHECK: {{\[}}16. 32. 64.]]
    back = np.array(dense)
    print(back)
180
181
# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
  """Round-trip a float32 numpy array through an f32 DenseElementsAttr."""
  rows = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
  with Context():
    dense = DenseElementsAttr.get(np.array(rows, dtype=np.float32))
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
    print(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(np.array(dense))
193
194
# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
  """Round-trip a float64 numpy array through an f64 DenseElementsAttr."""
  doubles = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
  with Context():
    dense = DenseElementsAttr.get(doubles)
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
    print(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    back = np.array(dense)
    print(back)
206
207
208### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
  """int16 data with default arguments yields a signless i16 tensor."""
  data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
220
221
# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
  """uint16 data with default arguments also yields a signless i16 tensor."""
  data = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(dense)
    print(round_trip)
233
234
# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
  """int16 data with signless=False yields a signed si16 tensor."""
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
  with Context():
    signed_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
    print(signed_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(signed_attr))
246
247
# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
  """uint16 data with signless=False yields an unsigned ui16 tensor."""
  values = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
  with Context():
    unsigned_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
    print(unsigned_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(unsigned_attr)
    print(round_trip)
259
260### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
  """int32 data with default arguments yields a signless i32 tensor."""
  data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
272
273
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
  """uint32 data with default arguments also yields a signless i32 tensor."""
  data = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(dense)
    print(round_trip)
285
286
# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
  """int32 data with signless=False yields a signed si32 tensor."""
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
  with Context():
    signed_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
    print(signed_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(signed_attr))
298
299
# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
  """uint32 data with signless=False yields an unsigned ui32 tensor."""
  values = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
  with Context():
    unsigned_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
    print(unsigned_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(unsigned_attr)
    print(round_trip)
311
312
### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
  """int64 data with default arguments yields a signless i64 tensor."""
  data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
325
326
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
  """uint64 data with default arguments also yields a signless i64 tensor."""
  data = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
  with Context():
    dense = DenseElementsAttr.get(data)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(dense)
    print(round_trip)
338
339
# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
  """int64 data with signless=False yields a signed si64 tensor."""
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
  with Context():
    signed_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
    print(signed_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(signed_attr))
351
352
# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
  """uint64 data with signless=False yields an unsigned ui64 tensor."""
  values = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
  with Context():
    unsigned_attr = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
    print(unsigned_attr)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    round_trip = np.array(unsigned_attr)
    print(round_trip)
364
365