xref: /llvm-project/mlir/test/python/ir/array_attributes.py (revision f66cd9e9556a53142a26a5c21a72e21f1579217c)
1# RUN: %PYTHON %s | FileCheck %s
2# Note that this is separate from ir_attributes.py since it depends on numpy,
3# and we may want to disable if not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8import weakref
9
10
11def run(f):
12    print("\nTEST:", f.__name__)
13    f()
14    gc.collect()
15    assert Context._get_live_count() == 0
16    return f
17
18
19################################################################################
20# Tests of the array/buffer .get() factory method on unsupported dtype.
21################################################################################
22
23
24@run
25def testGetDenseElementsUnsupported():
26    with Context():
27        array = np.array([["hello", "goodbye"]])
28        try:
29            attr = DenseElementsAttr.get(array)
30        except ValueError as e:
31            # CHECK: unimplemented array format conversion from format:
32            print(e)
33
34# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
35@run
36def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
37    with Context():
38        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
39        # datetime64 specifically isn't important: it's just a 64-bit type that
40        # doesn't have a format under the Python buffer protocol. A more
41        # realistic example would be a NumPy extension type like the bfloat16
42        # type from the ml_dtypes package, which isn't a dependency of this
43        # test.
44        attr = DenseElementsAttr.get(array.view(np.datetime64),
45                                     type=IntegerType.get_signless(64))
46        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
47        print(attr)
48        # CHECK: {{\[}}[1 2 3]
49        # CHECK: {{\[}}4 5 6]]
50        print(np.array(attr))
51
52
53################################################################################
54# Splats.
55################################################################################
56
57# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
58@run
59def testGetDenseElementsSplatInt():
60    with Context(), Location.unknown():
61        t = IntegerType.get_signless(32)
62        element = IntegerAttr.get(t, 555)
63        shaped_type = RankedTensorType.get((2, 3, 4), t)
64        attr = DenseElementsAttr.get_splat(shaped_type, element)
65        # CHECK: dense<555> : tensor<2x3x4xi32>
66        print(attr)
67        # CHECK: is_splat: True
68        print("is_splat:", attr.is_splat)
69
70        # CHECK: splat_value: IntegerAttr(555 : i32)
71        splat_value = attr.get_splat_value()
72        print("splat_value:", repr(splat_value))
73        assert splat_value == element
74
75
76# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
77@run
78def testGetDenseElementsSplatFloat():
79    with Context(), Location.unknown():
80        t = F32Type.get()
81        element = FloatAttr.get(t, 1.2)
82        shaped_type = RankedTensorType.get((2, 3, 4), t)
83        attr = DenseElementsAttr.get_splat(shaped_type, element)
84        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
85        print(attr)
86        assert attr.get_splat_value() == element
87
88
89# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
90@run
91def testGetDenseElementsSplatErrors():
92    with Context(), Location.unknown():
93        t = F32Type.get()
94        other_t = F64Type.get()
95        element = FloatAttr.get(t, 1.2)
96        other_element = FloatAttr.get(other_t, 1.2)
97        shaped_type = RankedTensorType.get((2, 3, 4), t)
98        dynamic_shaped_type = UnrankedTensorType.get(t)
99        non_shaped_type = t
100
101        try:
102            attr = DenseElementsAttr.get_splat(non_shaped_type, element)
103        except ValueError as e:
104            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
105            print(e)
106
107        try:
108            attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
109        except ValueError as e:
110            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
111            print(e)
112
113        try:
114            attr = DenseElementsAttr.get_splat(shaped_type, other_element)
115        except ValueError as e:
116            # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
117            print(e)
118
119
120# CHECK-LABEL: TEST: testRepeatedValuesSplat
121@run
122def testRepeatedValuesSplat():
123    with Context():
124        array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
125        attr = DenseElementsAttr.get(array)
126        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
127        print(attr)
128        # CHECK: is_splat: True
129        print("is_splat:", attr.is_splat)
130        # CHECK{LITERAL}: [[1. 1. 1.]
131        # CHECK{LITERAL}:  [1. 1. 1.]]
132        print(np.array(attr))
133
134
135# CHECK-LABEL: TEST: testNonSplat
136@run
137def testNonSplat():
138    with Context():
139        array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
140        attr = DenseElementsAttr.get(array)
141        # CHECK: is_splat: False
142        print("is_splat:", attr.is_splat)
143
144
145################################################################################
146# Tests of the array/buffer .get() factory method, in all of its permutations.
147################################################################################
148
149### explicitly provided types
150
151
152@run
153def testGetDenseElementsBF16():
154    with Context():
155        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
156        attr = DenseElementsAttr.get(array, type=BF16Type.get())
157        # Note: These values don't mean much since just bit-casting. But they
158        # shouldn't change.
159        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
160        print(attr)
161
162
163@run
164def testGetDenseElementsInteger4():
165    with Context():
166        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
167        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
168        # Note: These values don't mean much since just bit-casting. But they
169        # shouldn't change.
170        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
171        print(attr)
172
173
174@run
175def testGetDenseElementsBool():
176    with Context():
177        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
178        array = np.packbits(bool_array, axis=None, bitorder="little")
179        attr = DenseElementsAttr.get(
180            array, type=IntegerType.get_signless(1), shape=bool_array.shape
181        )
182        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
183        print(attr)
184
185
186@run
187def testGetDenseElementsBoolSplat():
188    with Context():
189        zero = np.array(0, dtype=np.uint8)
190        one = np.array(255, dtype=np.uint8)
191        print(one)
192        # CHECK: dense<false> : tensor<4x2x5xi1>
193        print(
194            DenseElementsAttr.get(
195                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
196            )
197        )
198        # CHECK: dense<true> : tensor<4x2x5xi1>
199        print(
200            DenseElementsAttr.get(
201                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
202            )
203        )
204
205
206### float and double arrays.
207
208# CHECK-LABEL: TEST: testGetDenseElementsF16
209@run
210def testGetDenseElementsF16():
211    with Context():
212        array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
213        attr = DenseElementsAttr.get(array)
214        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
215        print(attr)
216        # CHECK: {{\[}}[ 2. 4. 8.]
217        # CHECK: {{\[}}16. 32. 64.]]
218        print(np.array(attr))
219
220
221# CHECK-LABEL: TEST: testGetDenseElementsF32
222@run
223def testGetDenseElementsF32():
224    with Context():
225        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
226        attr = DenseElementsAttr.get(array)
227        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
228        print(attr)
229        # CHECK: {{\[}}[1.1 2.2 3.3]
230        # CHECK: {{\[}}4.4 5.5 6.6]]
231        print(np.array(attr))
232
233
234# CHECK-LABEL: TEST: testGetDenseElementsF64
235@run
236def testGetDenseElementsF64():
237    with Context():
238        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
239        attr = DenseElementsAttr.get(array)
240        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
241        print(attr)
242        # CHECK: {{\[}}[1.1 2.2 3.3]
243        # CHECK: {{\[}}4.4 5.5 6.6]]
244        print(np.array(attr))
245
246
247### 16 bit integer arrays
248# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
249@run
250def testGetDenseElementsI16Signless():
251    with Context():
252        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
253        attr = DenseElementsAttr.get(array)
254        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
255        print(attr)
256        # CHECK: {{\[}}[1 2 3]
257        # CHECK: {{\[}}4 5 6]]
258        print(np.array(attr))
259
260
261# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
262@run
263def testGetDenseElementsUI16Signless():
264    with Context():
265        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
266        attr = DenseElementsAttr.get(array)
267        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
268        print(attr)
269        # CHECK: {{\[}}[1 2 3]
270        # CHECK: {{\[}}4 5 6]]
271        print(np.array(attr))
272
273
274# CHECK-LABEL: TEST: testGetDenseElementsI16
275@run
276def testGetDenseElementsI16():
277    with Context():
278        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
279        attr = DenseElementsAttr.get(array, signless=False)
280        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
281        print(attr)
282        # CHECK: {{\[}}[1 2 3]
283        # CHECK: {{\[}}4 5 6]]
284        print(np.array(attr))
285
286
287# CHECK-LABEL: TEST: testGetDenseElementsUI16
288@run
289def testGetDenseElementsUI16():
290    with Context():
291        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
292        attr = DenseElementsAttr.get(array, signless=False)
293        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
294        print(attr)
295        # CHECK: {{\[}}[1 2 3]
296        # CHECK: {{\[}}4 5 6]]
297        print(np.array(attr))
298
299
300### 32 bit integer arrays
301# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
302@run
303def testGetDenseElementsI32Signless():
304    with Context():
305        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
306        attr = DenseElementsAttr.get(array)
307        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
308        print(attr)
309        # CHECK: {{\[}}[1 2 3]
310        # CHECK: {{\[}}4 5 6]]
311        print(np.array(attr))
312
313
314# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
315@run
316def testGetDenseElementsUI32Signless():
317    with Context():
318        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
319        attr = DenseElementsAttr.get(array)
320        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
321        print(attr)
322        # CHECK: {{\[}}[1 2 3]
323        # CHECK: {{\[}}4 5 6]]
324        print(np.array(attr))
325
326
327# CHECK-LABEL: TEST: testGetDenseElementsI32
328@run
329def testGetDenseElementsI32():
330    with Context():
331        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
332        attr = DenseElementsAttr.get(array, signless=False)
333        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
334        print(attr)
335        # CHECK: {{\[}}[1 2 3]
336        # CHECK: {{\[}}4 5 6]]
337        print(np.array(attr))
338
339
340# CHECK-LABEL: TEST: testGetDenseElementsUI32
341@run
342def testGetDenseElementsUI32():
343    with Context():
344        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
345        attr = DenseElementsAttr.get(array, signless=False)
346        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
347        print(attr)
348        # CHECK: {{\[}}[1 2 3]
349        # CHECK: {{\[}}4 5 6]]
350        print(np.array(attr))
351
352
353## 64bit integer arrays
354# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
355@run
356def testGetDenseElementsI64Signless():
357    with Context():
358        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
359        attr = DenseElementsAttr.get(array)
360        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
361        print(attr)
362        # CHECK: {{\[}}[1 2 3]
363        # CHECK: {{\[}}4 5 6]]
364        print(np.array(attr))
365
366
367# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
368@run
369def testGetDenseElementsUI64Signless():
370    with Context():
371        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
372        attr = DenseElementsAttr.get(array)
373        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
374        print(attr)
375        # CHECK: {{\[}}[1 2 3]
376        # CHECK: {{\[}}4 5 6]]
377        print(np.array(attr))
378
379
380# CHECK-LABEL: TEST: testGetDenseElementsI64
381@run
382def testGetDenseElementsI64():
383    with Context():
384        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
385        attr = DenseElementsAttr.get(array, signless=False)
386        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
387        print(attr)
388        # CHECK: {{\[}}[1 2 3]
389        # CHECK: {{\[}}4 5 6]]
390        print(np.array(attr))
391
392
393# CHECK-LABEL: TEST: testGetDenseElementsUI64
394@run
395def testGetDenseElementsUI64():
396    with Context():
397        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
398        attr = DenseElementsAttr.get(array, signless=False)
399        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
400        print(attr)
401        # CHECK: {{\[}}[1 2 3]
402        # CHECK: {{\[}}4 5 6]]
403        print(np.array(attr))
404
405
406# CHECK-LABEL: TEST: testGetDenseElementsIndex
407@run
408def testGetDenseElementsIndex():
409    with Context(), Location.unknown():
410        idx_type = IndexType.get()
411        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
412        attr = DenseElementsAttr.get(array, type=idx_type)
413        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
414        print(attr)
415        arr = np.array(attr)
416        # CHECK: {{\[}}[1 2 3]
417        # CHECK: {{\[}}4 5 6]]
418        print(arr)
419        # CHECK: True
420        print(arr.dtype == np.int64)
421
422
423# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr
424@run
425def testGetDenseResourceElementsAttr():
426    def on_delete(_):
427        print("BACKING MEMORY DELETED")
428
429    context = Context()
430    mview = memoryview(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
431    ref = weakref.ref(mview, on_delete)
432
433    def test_attribute(context, mview):
434        with context, Location.unknown():
435            element_type = IntegerType.get_signless(32)
436            tensor_type = RankedTensorType.get((2, 3), element_type)
437            resource = DenseResourceElementsAttr.get_from_buffer(
438                mview, "from_py", tensor_type
439            )
440            module = Module.parse("module {}")
441            module.operation.attributes["test.resource"] = resource
442            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
443            # CHECK: from_py: "0x04000000010000000200000003000000040000000500000006000000"
444            print(module)
445
446            # Verifies type casting.
447            # CHECK: dense_resource<from_py> : tensor<2x3xi32>
448            print(
449                DenseResourceElementsAttr(module.operation.attributes["test.resource"])
450            )
451
452    test_attribute(context, mview)
453    mview = None
454    gc.collect()
455    # CHECK: FREEING CONTEXT
456    print("FREEING CONTEXT")
457    context = None
458    gc.collect()
459    # CHECK: BACKING MEMORY DELETED
460    # CHECK: EXIT FUNCTION
461    print("EXIT FUNCTION")
462