# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it if numpy is not available.

import gc
from mlir.ir import *
import numpy as np
import weakref


def run(f):
    """Run test function `f` immediately and verify it leaked no Context.

    Used as a decorator, so every test executes at import time. The
    "TEST: <name>" line it prints is what each `CHECK-LABEL` matches. After
    the test, a garbage collection is forced and the number of live MLIR
    `Context` objects must be zero. `f` is returned unchanged so the module
    name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so contexts that are only reachable from cycles are freed
    # before the leak check.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtypes.
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
    """Building a dense attribute from a NumPy string array raises ValueError."""
    with Context():
        array = np.array([["hello", "goodbye"]])
        try:
            attr = DenseElementsAttr.get(array)
        except ValueError as e:
            # CHECK: unimplemented array format conversion from format:
            print(e)


# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
    """An unsupported buffer dtype is accepted when an explicit type is given."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        # datetime64 specifically isn't important: it's just a 64-bit type that
        # doesn't have a format under the Python buffer protocol. A more
        # realistic example would be a NumPy extension type like the bfloat16
        # type from the ml_dtypes package, which isn't a dependency of this
        # test.
        attr = DenseElementsAttr.get(
            array.view(np.datetime64), type=IntegerType.get_signless(64)
        )
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


################################################################################
# Splats.
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
    """An i32 splat prints in splat form and round-trips its splat value."""
    with Context(), Location.unknown():
        t = IntegerType.get_signless(32)
        element = IntegerAttr.get(t, 555)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        attr = DenseElementsAttr.get_splat(shaped_type, element)
        # CHECK: dense<555> : tensor<2x3x4xi32>
        print(attr)
        # CHECK: is_splat: True
        print("is_splat:", attr.is_splat)

        # CHECK: splat_value: IntegerAttr(555 : i32)
        splat_value = attr.get_splat_value()
        print("splat_value:", repr(splat_value))
        assert splat_value == element


# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
    """An f32 splat prints in splat form and round-trips its splat value."""
    with Context(), Location.unknown():
        t = F32Type.get()
        element = FloatAttr.get(t, 1.2)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        attr = DenseElementsAttr.get_splat(shaped_type, element)
        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
        print(attr)
        assert attr.get_splat_value() == element


# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
    """get_splat rejects non-shaped, unranked and element-mismatched inputs."""
    with Context(), Location.unknown():
        t = F32Type.get()
        other_t = F64Type.get()
        element = FloatAttr.get(t, 1.2)
        other_element = FloatAttr.get(other_t, 1.2)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        unranked_type = UnrankedTensorType.get(t)
        non_shaped_type = t

        try:
            attr = DenseElementsAttr.get_splat(non_shaped_type, element)
        except ValueError as e:
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
            print(e)

        try:
            attr = DenseElementsAttr.get_splat(unranked_type, element)
        except ValueError as e:
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
            print(e)

        try:
            attr = DenseElementsAttr.get_splat(shaped_type, other_element)
        except ValueError as e:
            # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
            print(e)


# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
    """An array of identical values is reported and printed as a splat.

    Converting it back with np.array still yields the full 2x3 array.
    """
    with Context():
        array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
        print(attr)
        # CHECK: is_splat: True
        print("is_splat:", attr.is_splat)
        # CHECK{LITERAL}: [[1. 1. 1.]
        # CHECK{LITERAL}: [1. 1. 1.]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
    """An array with differing values is not reported as a splat."""
    with Context():
        array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: is_splat: False
        print("is_splat:", attr.is_splat)


################################################################################
# Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################

### explicitly provided types


# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
    """Raw uint16 data is reinterpreted as bf16 when the type is given."""
    with Context():
        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, type=BF16Type.get())
        # Note: These values don't mean much, since the bits are simply
        # reinterpreted as bf16. But they shouldn't change.
        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
    """int8 data can populate an i4 tensor when the type is given."""
    with Context():
        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
        # Each int8 element supplies one i4 value. All inputs lie within the
        # signed i4 range [-8, 7], so they print unchanged.
        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
    """Bit-packed bytes populate an i1 tensor of the given shape."""
    with Context():
        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
        # packbits with axis=None flattens the input (LSB-first within each
        # byte), so the logical shape must be passed explicitly below.
        array = np.packbits(bool_array, axis=None, bitorder="little")
        attr = DenseElementsAttr.get(
            array, type=IntegerType.get_signless(1), shape=bool_array.shape
        )
        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
    """A single all-zero or all-one byte yields an i1 splat of any shape."""
    with Context():
        zero = np.array(0, dtype=np.uint8)
        one = np.array(255, dtype=np.uint8)
        # CHECK: dense<false> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
        # CHECK: dense<true> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )


### float and double arrays.


# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
    with Context():
        array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
        print(attr)
        # CHECK: {{\[}}[ 2. 4. 8.]
        # CHECK: {{\[}}16. 32. 64.]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
    with Context():
        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
        print(attr)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
    with Context():
        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
        print(attr)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(attr))


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsIndex
@run
def testGetDenseElementsIndex():
    """int64 data populates an index tensor and converts back to int64."""
    with Context(), Location.unknown():
        idx_type = IndexType.get()
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array, type=idx_type)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
        print(attr)
        arr = np.array(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(arr)
        # CHECK: True
        print(arr.dtype == np.int64)


# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr
@run
def testGetDenseResourceElementsAttr():
    """Wrap a Python buffer in a dense_resource attribute and check its lifetime.

    The CHECK ordering below pins that the backing memoryview is destroyed
    only after the owning context is released ("FREEING CONTEXT"). It must
    also be destroyed before the function ends ("EXIT FUNCTION").
    """

    def on_delete(_):
        print("BACKING MEMORY DELETED")

    context = Context()
    mview = memoryview(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
    # Keep the weakref object itself alive for the whole test. If it were
    # collected before the memoryview, on_delete would never be called.
    ref = weakref.ref(mview, on_delete)

    def test_attribute(context, mview):
        with context, Location.unknown():
            element_type = IntegerType.get_signless(32)
            tensor_type = RankedTensorType.get((2, 3), element_type)
            resource = DenseResourceElementsAttr.get_from_buffer(
                mview, "from_py", tensor_type
            )
            module = Module.parse("module {}")
            module.operation.attributes["test.resource"] = resource
            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
            # CHECK: from_py: "0x04000000010000000200000003000000040000000500000006000000"
            print(module)

            # Verifies casting a generic Attribute back to the concrete class.
            # CHECK: dense_resource<from_py> : tensor<2x3xi32>
            print(
                DenseResourceElementsAttr(module.operation.attributes["test.resource"])
            )

    test_attribute(context, mview)
    # Drop the local reference. Per the CHECK order, the buffer must survive
    # this collection because the context still holds the resource.
    mview = None
    gc.collect()
    # CHECK: FREEING CONTEXT
    print("FREEING CONTEXT")
    context = None
    gc.collect()
    # CHECK: BACKING MEMORY DELETED
    # CHECK: EXIT FUNCTION
    print("EXIT FUNCTION")