xref: /llvm-project/mlir/test/python/ir/array_attributes.py (revision 0a68171b3c67503f7143856580f1b22a93ef566e)
1# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it when numpy is not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8import weakref
9
10
def run(f):
    """Run test function `f` immediately and verify it leaks no Context.

    Intended for use as a decorator. It prints a ``TEST: <name>`` header (the
    line each test's label directive anchors on), calls `f`, and forces a
    garbage collection. It then asserts that no ``Context`` objects remain
    alive. `f` is returned unchanged.

    Raises:
        AssertionError: if any Context is still alive after collection; the
            message names the offending test and the number of live contexts.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so unreachable objects caught in reference cycles are
    # reclaimed before the live contexts are counted.
    gc.collect()
    live_count = Context._get_live_count()
    assert live_count == 0, f"{f.__name__} leaked {live_count} live Context(s)"
    return f
17
18
19################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtypes.
21################################################################################
22
23
# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
    """A numpy array of strings is rejected with a ValueError."""
    with Context():
        array = np.array([["hello", "goodbye"]])
        try:
            DenseElementsAttr.get(array)
        except ValueError as e:
            # CHECK: unimplemented array format conversion from format:
            print(e)
33
# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
    """An explicit element type makes an otherwise unsupported buffer usable."""
    with Context():
        int_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        # The choice of datetime64 is incidental: it is simply a 64-bit dtype
        # with no Python buffer-protocol format. A more realistic case would
        # be a NumPy extension type such as bfloat16 from the ml_dtypes
        # package, which this test deliberately does not depend on.
        opaque_view = int_data.view(np.datetime64)
        i64 = IntegerType.get_signless(64)
        dense = DenseElementsAttr.get(opaque_view, type=i64)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
51
52
53################################################################################
54# Tests of the list of attributes .get() factory method
55################################################################################
56
57
# CHECK-LABEL: TEST: testGetDenseElementsFromList
@run
def testGetDenseElementsFromList():
    """Build a DenseElementsAttr from f64 attributes without an explicit type."""
    with Context(), Location.unknown():
        f64 = F64Type.get()
        elements = [FloatAttr.get(f64, value) for value in (1.0, 2.0)]
        dense = DenseElementsAttr.get(elements)

        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
        print(dense)
67
68
# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
@run
def testGetDenseElementsFromListWithExplicitType():
    """Build a DenseElementsAttr from f64 attributes with a matching tensor type."""
    with Context(), Location.unknown():
        f64 = F64Type.get()
        elements = [FloatAttr.get(f64, value) for value in (1.0, 2.0)]
        tensor_type = ShapedType(Type.parse("tensor<2xf64>"))
        dense = DenseElementsAttr.get(elements, tensor_type)

        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
        print(dense)
79
80
# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
@run
def testGetDenseElementsFromListEmptyList():
    """An empty attribute list is rejected with a ValueError."""
    with Context(), Location.unknown():
        try:
            DenseElementsAttr.get([])
        except ValueError as err:
            # CHECK: Attributes list must be non-empty
            print(err)
92
93
# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
@run
def testGetDenseElementsFromListNonAttributeType():
    """A list holding a plain Python float instead of an Attribute is rejected."""
    with Context(), Location.unknown():
        not_attributes = [1.0]
        try:
            DenseElementsAttr.get(not_attributes)
        except RuntimeError as err:
            # CHECK: Invalid attribute when attempting to create an ArrayAttribute
            print(err)
105
106
# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
@run
def testGetDenseElementsFromListMismatchedType():
    """f64 attributes paired with an f32 tensor type are rejected."""
    with Context(), Location.unknown():
        f64 = F64Type.get()
        elements = [FloatAttr.get(f64, 1.0), FloatAttr.get(f64, 2.0)]
        f32_tensor = ShapedType(Type.parse("tensor<2xf32>"))

        try:
            DenseElementsAttr.get(elements, f32_tensor)
        except ValueError as err:
            # CHECK: All attributes must be of the same type and match the type parameter
            print(err)
119
120
# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
@run
def testGetDenseElementsFromListMixedTypes():
    """A list mixing f64 and f32 attributes is rejected."""
    with Context(), Location.unknown():
        mixed = [
            FloatAttr.get(F64Type.get(), 1.0),
            FloatAttr.get(F32Type.get(), 2.0),
        ]

        try:
            DenseElementsAttr.get(mixed)
        except ValueError as err:
            # CHECK: All attributes must be of the same type and match the type parameter
            print(err)
132
133
134################################################################################
135# Splats.
136################################################################################
137
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
    """Splat an i32 IntegerAttr across a 2x3x4 tensor and read it back."""
    with Context(), Location.unknown():
        i32 = IntegerType.get_signless(32)
        value_attr = IntegerAttr.get(i32, 555)
        tensor_type = RankedTensorType.get((2, 3, 4), i32)
        splat = DenseElementsAttr.get_splat(tensor_type, value_attr)
        # CHECK: dense<555> : tensor<2x3x4xi32>
        print(splat)
        # CHECK: is_splat: True
        print("is_splat:", splat.is_splat)

        # CHECK: splat_value: IntegerAttr(555 : i32)
        recovered = splat.get_splat_value()
        print("splat_value:", repr(recovered))
        assert recovered == value_attr
155
156
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
    """Splat an f32 FloatAttr across a 2x3x4 tensor and read it back."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        value_attr = FloatAttr.get(f32, 1.2)
        splat = DenseElementsAttr.get_splat(
            RankedTensorType.get((2, 3, 4), f32), value_attr
        )
        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
        print(splat)
        assert splat.get_splat_value() == value_attr
168
169
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
    """Each invalid get_splat argument combination raises a ValueError."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        f64 = F64Type.get()
        f32_value = FloatAttr.get(f32, 1.2)
        f64_value = FloatAttr.get(f64, 1.2)
        static_type = RankedTensorType.get((2, 3, 4), f32)
        unranked_type = UnrankedTensorType.get(f32)

        # Each case below is expected to raise; its message is printed in
        # order, matching the three expectations that follow.
        # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
        # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
        # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
        bad_cases = [
            (f32, f32_value),  # a scalar type, not a shaped type
            (unranked_type, f32_value),  # shaped, but not statically shaped
            (static_type, f64_value),  # element type disagrees with the tensor
        ]
        for shaped, element in bad_cases:
            try:
                DenseElementsAttr.get_splat(shaped, element)
            except ValueError as err:
                print(err)
199
200
# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
    """A numpy array whose elements are all equal is reported as a splat."""
    with Context():
        ones = np.ones((2, 3), dtype=np.float32)
        dense = DenseElementsAttr.get(ones)
        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
        print(dense)
        # CHECK: is_splat: True
        print("is_splat:", dense.is_splat)
        # CHECK{LITERAL}: [[1. 1. 1.]
        # CHECK{LITERAL}:  [1. 1. 1.]]
        print(np.array(dense))
214
215
# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
    """An array with differing elements is not reported as a splat."""
    with Context():
        values = np.array([2.0, 1.0, 1.0], dtype=np.float32)
        # CHECK: is_splat: False
        print("is_splat:", DenseElementsAttr.get(values).is_splat)
224
225
226################################################################################
227# Tests of the array/buffer .get() factory method, in all of its permutations.
228################################################################################
229
230### explicitly provided types
231
232
# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
    """Reinterpret a uint16 buffer as bf16 by passing an explicit element type."""
    with Context():
        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, type=BF16Type.get())
        # The uint16 bit patterns are reinterpreted as bf16, so these tiny
        # values mean nothing in themselves. They act as a regression check
        # that the raw bits come through unchanged.
        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
        print(attr)
242
243
# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
    """Build an i4 tensor attribute from an int8 buffer via an explicit type."""
    with Context():
        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
        # Every input lies within the signed 4-bit range [-8, 7], so the
        # printed attribute is expected to reproduce the inputs exactly.
        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
        print(attr)
253
254
# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
    """Build an i1 tensor attribute from a bit-packed boolean buffer."""
    with Context():
        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
        # Pack one bit per element, least-significant bit first. With
        # axis=None the result is flat, so the logical shape is passed
        # explicitly below.
        array = np.packbits(bool_array, axis=None, bitorder="little")
        attr = DenseElementsAttr.get(
            array, type=IntegerType.get_signless(1), shape=bool_array.shape
        )
        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
        print(attr)
265
266
# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
    """Build i1 splat attributes from single-byte scalar buffers."""
    with Context():
        # One packed byte each: 0x00 has every bit clear, 0xFF every bit set.
        zero = np.array(0, dtype=np.uint8)
        one = np.array(255, dtype=np.uint8)
        # CHECK: dense<false> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
        # CHECK: dense<true> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
285
286
287### float and double arrays.
288
289
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
    """float16 data converts to an f16 tensor and round-trips through numpy."""
    with Context():
        halves = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
        dense = DenseElementsAttr.get(halves)
        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
        print(dense)
        # CHECK: {{\[}}[ 2. 4. 8.]
        # CHECK: {{\[}}16. 32. 64.]]
        print(np.array(dense))
301
302
# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
    """float32 data converts to an f32 tensor and round-trips through numpy."""
    with Context():
        singles = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
        dense = DenseElementsAttr.get(singles)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
        print(dense)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(dense))
314
315
# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
    """float64 data converts to an f64 tensor and round-trips through numpy."""
    with Context():
        doubles = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
        dense = DenseElementsAttr.get(doubles)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
        print(dense)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(dense))
327
328
329### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
    """Without signless=False, int16 data yields a signless i16 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
341
342
# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
    """Without signless=False, uint16 data also yields a signless i16 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
354
355
# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
    """With signless=False, int16 data yields a signed si16 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int16).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
367
368
# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
    """With signless=False, uint16 data yields an unsigned ui16 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint16).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
380
381
382### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
    """Without signless=False, int32 data yields a signless i32 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
394
395
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
    """Without signless=False, uint32 data also yields a signless i32 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
407
408
# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
    """With signless=False, int32 data yields a signed si32 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
420
421
# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
    """With signless=False, uint32 data yields an unsigned ui32 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint32).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
433
434
### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
    """Without signless=False, int64 data yields a signless i64 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
447
448
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
    """Without signless=False, uint64 data also yields a signless i64 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
        dense = DenseElementsAttr.get(data)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
460
461
# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
    """With signless=False, int64 data yields a signed si64 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
473
474
# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
    """With signless=False, uint64 data yields an unsigned ui64 tensor."""
    with Context():
        data = np.arange(1, 7, dtype=np.uint64).reshape(2, 3)
        dense = DenseElementsAttr.get(data, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
        print(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(dense))
486
487
# CHECK-LABEL: TEST: testGetDenseElementsIndex
@run
def testGetDenseElementsIndex():
    """Store int64 data with the index element type and read it back as int64."""
    with Context(), Location.unknown():
        data = np.arange(1, 7, dtype=np.int64).reshape(2, 3)
        dense = DenseElementsAttr.get(data, type=IndexType.get())
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
        print(dense)
        roundtrip = np.array(dense)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(roundtrip)
        # CHECK: True
        print(roundtrip.dtype == np.int64)
503
504
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr
@run
def testGetDenseResourceElementsAttr():
    """Verify that a DenseResourceElementsAttr keeps its Python buffer alive.

    A memoryview over an int32 array is handed to
    ``DenseResourceElementsAttr.get_from_buffer``. The expected output order
    asserts two things. First, the backing memory is not released when this
    function drops its own reference. Second, it is released only after the
    Context has been dropped and collected.
    """

    def on_delete(_):
        # weakref callback: runs when the memoryview is about to be finalized.
        print("BACKING MEMORY DELETED")

    context = Context()
    mview = memoryview(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
    # The weakref object must itself stay alive (hence the unused binding).
    # Otherwise its callback would never fire.
    ref = weakref.ref(mview, on_delete)

    def test_attribute(context, mview):
        # Kept in its own scope so that `resource`, `module`, etc. become
        # unreachable when it returns, before the context is released below.
        with context, Location.unknown():
            element_type = IntegerType.get_signless(32)
            tensor_type = RankedTensorType.get((2, 3), element_type)
            # "from_py" is the resource name shown in the printed module.
            resource = DenseResourceElementsAttr.get_from_buffer(
                mview, "from_py", tensor_type
            )
            module = Module.parse("module {}")
            module.operation.attributes["test.resource"] = resource
            # The hex blob ends with the six int32 values 1..6 in little-endian
            # order; the leading 04000000 is presumably an alignment header
            # (NOTE(review): confirm against the resource blob encoding).
            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
            # CHECK: from_py: "0x04000000010000000200000003000000040000000500000006000000"
            print(module)

            # Verifies type casting.
            # Wrapping the generic attribute read back from the op in
            # DenseResourceElementsAttr(...) must succeed and print the same
            # resource.
            # CHECK: dense_resource<from_py> : tensor<2x3xi32>
            print(
                DenseResourceElementsAttr(module.operation.attributes["test.resource"])
            )

    test_attribute(context, mview)
    # Drop this frame's reference to the buffer. Per the expected output order,
    # the buffer must survive this collection (presumably still held by the
    # context's resource blob).
    mview = None
    gc.collect()
    # CHECK: FREEING CONTEXT
    print("FREEING CONTEXT")
    # Releasing the context is expected to release the buffer, which fires
    # on_delete before the final marker is printed.
    context = None
    gc.collect()
    # CHECK: BACKING MEMORY DELETED
    # CHECK: EXIT FUNCTION
    print("EXIT FUNCTION")
544