xref: /llvm-project/mlir/test/python/ir/array_attributes.py (revision c5f445d143485f898353df6d422eea1dea22c7a8)
1# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it when numpy is not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8
def run(f):
  """Decorator that executes a test function immediately at import time.

  Prints a "TEST: <name>" banner (the text the CHECK-LABEL lines match),
  calls `f`, then forces a garbage collection and asserts that no MLIR
  `Context` is still alive, i.e. the test leaked none. `f` is returned
  unchanged so the decorated name stays bound to the function.
  """
  name = f.__name__
  print("\nTEST:", name)
  f()
  # Collect first so contexts reachable only through garbage cycles are
  # released before they are counted.
  gc.collect()
  assert Context._get_live_count() == 0
  return f
15
16################################################################################
17# Tests of the array/buffer .get() factory method on unsupported dtype.
18################################################################################
19
# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
  """DenseElementsAttr.get rejects an array of strings with a ValueError.

  NumPy stores these strings as a fixed-width unicode array, which has no
  conversion in the binding; the error message is printed and matched below.
  """
  with Context():
    array = np.array([["hello", "goodbye"]])
    try:
      DenseElementsAttr.get(array)
    except ValueError as e:
      # CHECK: unimplemented array format conversion from format:
      print(e)
29
30################################################################################
31# Splats.
32################################################################################
33
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
  """get_splat with an i32 IntegerAttr fills a 2x3x4 tensor with one value."""
  with Context(), Location.unknown():
    i32 = IntegerType.get_signless(32)
    splat_value = IntegerAttr.get(i32, 555)
    tensor_type = RankedTensorType.get((2, 3, 4), i32)
    splat = DenseElementsAttr.get_splat(tensor_type, splat_value)
    # CHECK: dense<555> : tensor<2x3x4xi32>
    print(splat)
    # CHECK: is_splat: True
    print("is_splat:", splat.is_splat)
46
47
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
  """get_splat with an f32 FloatAttr fills a 2x3x4 tensor with one value."""
  with Context(), Location.unknown():
    f32 = F32Type.get()
    splat_value = FloatAttr.get(f32, 1.2)
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
    print(DenseElementsAttr.get_splat(tensor_type, splat_value))
57    print(attr)
58
59
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
  """get_splat raises ValueError for each kind of invalid argument.

  Covers a non-shaped type, an unranked tensor type, and an element attribute
  whose type differs from the tensor's element type. Each error message is
  printed and matched below, in call order.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    f64 = F64Type.get()
    f32_value = FloatAttr.get(f32, 1.2)
    f64_value = FloatAttr.get(f64, 1.2)
    ranked_type = RankedTensorType.get((2, 3, 4), f32)
    unranked_type = UnrankedTensorType.get(f32)

    bad_calls = (
        (f32, f32_value),
        (unranked_type, f32_value),
        (ranked_type, f64_value),
    )
    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
    # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
    for shaped, element in bad_calls:
      try:
        DenseElementsAttr.get_splat(shaped, element)
      except ValueError as e:
        print(e)
89
90
# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
  """An array whose elements all repeat one value is detected as a splat."""
  with Context():
    ones = np.ones((2, 3), dtype=np.float32)
    dense = DenseElementsAttr.get(ones)
    # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
    print(dense)
    # CHECK: is_splat: True
    print("is_splat:", dense.is_splat)
    # TODO: Re-enable this once a solution is found to raising an exception
    # from buffer protocol.
    # Reported as https://github.com/pybind/pybind11/issues/3336
    # print(np.array(dense))
105
106
# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
  """An array with differing elements is not reported as a splat."""
  with Context():
    dense = DenseElementsAttr.get(np.array([2.0, 1.0, 1.0], dtype=np.float32))
    # CHECK: is_splat: False
    print("is_splat:", dense.is_splat)
115
116
117################################################################################
118# Tests of the array/buffer .get() factory method, in all of its permutations.
119################################################################################
120
121### explicitly provided types
122
# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
  """Build a bf16 tensor from raw uint16 data via an explicit element type."""
  with Context():
    array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
    attr = DenseElementsAttr.get(array, type=BF16Type.get())
    # Note: the uint16 values are bit-cast to bf16 rather than converted, so
    # the printed floats are tiny and not meaningful in themselves. They are
    # pinned so that any change in the bit-casting is caught.
    # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
    print(attr)
132
# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
  """Build an i4 tensor from one-byte-per-element data via an explicit type."""
  with Context():
    # Build the bytes as int8 and reinterpret them as uint8 instead of passing
    # negative Python ints with dtype=np.uint8. NumPy 1.24 deprecated that
    # out-of-bounds conversion, and NumPy 2 raises OverflowError for it. The
    # resulting uint8 buffer holds the same bytes the wrapped values would.
    array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8).view(np.uint8)
    attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
    # Note: the bytes are bit-cast to i4 rather than converted. All values
    # used here lie in the i4 range [-8, 7], so they are expected to print back
    # unchanged.
    # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
    print(attr)
142
143
# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
  """Build an i1 tensor from bit-packed booleans plus an explicit shape.

  np.packbits (axis=None, bitorder="little") flattens the six booleans into a
  single byte, least-significant bit first. The logical 2x3 shape therefore
  cannot be recovered from the buffer and is passed separately.
  """
  with Context():
    bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
    array = np.packbits(bool_array, axis=None, bitorder="little")
    attr = DenseElementsAttr.get(
        array, type=IntegerType.get_signless(1), shape=bool_array.shape)
    # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
    print(attr)
153
154
# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
  """Build i1 splats from a single packed byte and an explicit larger shape.

  A 0-d uint8 array with an explicit i1 type and shape is expected to print
  as a splat: all bits clear gives `false`, and all bits set (255) gives
  `true`.
  """
  with Context():
    zero = np.array(0, dtype=np.uint8)
    one = np.array(255, dtype=np.uint8)
    # (A stray debug `print(one)` was removed here; nothing checked it.)
    # CHECK: dense<false> : tensor<4x2x5xi1>
    print(DenseElementsAttr.get(
        zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
    # CHECK: dense<true> : tensor<4x2x5xi1>
    print(DenseElementsAttr.get(
        one, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
167
168
169### float and double arrays.
170
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
  """float16 data maps to an f16 tensor and converts back to a NumPy array."""
  with Context():
    rows = [[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]]
    dense = DenseElementsAttr.get(np.array(rows, dtype=np.float16))
    # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
    print(dense)
    # CHECK: {{\[}}[ 2. 4. 8.]
    # CHECK: {{\[}}16. 32. 64.]]
    print(np.array(dense))
182
183
# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
  """float32 data maps to an f32 tensor and converts back to a NumPy array."""
  with Context():
    rows = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
    dense = DenseElementsAttr.get(np.array(rows, dtype=np.float32))
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
    print(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(np.array(dense))
195
196
# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
  """float64 data maps to an f64 tensor and converts back to a NumPy array."""
  with Context():
    rows = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
    dense = DenseElementsAttr.get(np.array(rows, dtype=np.float64))
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
    print(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(np.array(dense))
208
209
210### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
  """int16 data without `signless` yields an i16 tensor that round-trips."""
  with Context():
    values = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
222
223
# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
  """uint16 data without `signless` also yields a signless i16 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
235
236
# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
  """With signless=False, int16 data yields a signed si16 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
248
249
# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
  """With signless=False, uint16 data yields an unsigned ui16 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
261
262### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
  """int32 data without `signless` yields an i32 tensor that round-trips."""
  with Context():
    values = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
274
275
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
  """uint32 data without `signless` also yields a signless i32 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
287
288
# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
  """With signless=False, int32 data yields a signed si32 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
300
301
# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
  """With signless=False, uint32 data yields an unsigned ui32 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
313
314
### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
  """int64 data without `signless` yields an i64 tensor that round-trips."""
  with Context():
    values = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
327
328
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
  """uint64 data without `signless` also yields a signless i64 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
340
341
# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
  """With signless=False, int64 data yields a signed si64 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
353
354
# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
  """With signless=False, uint64 data yields an unsigned ui64 tensor."""
  with Context():
    values = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))
366
367