# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable if not available.

import gc
from mlir.ir import *
import numpy as np
import weakref


def run(f):
    """Run test ``f`` and verify it leaks no MLIR contexts.

    Prints a ``TEST: <name>`` banner (the text matched by the
    ``CHECK-LABEL: TEST: ...`` lines), invokes ``f``, forces a garbage
    collection and asserts that no ``Context`` objects remain alive.
    Returns ``f`` unchanged so this can be used as a decorator.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtype.
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
    """A NumPy string array has no convertible buffer format and is rejected."""
    with Context():
        array = np.array([["hello", "goodbye"]])
        try:
            attr = DenseElementsAttr.get(array)
        except ValueError as e:
            # CHECK: unimplemented array format conversion from format:
            print(e)


# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        # datetime64 specifically isn't important: it's just a 64-bit type that
        # doesn't have a format under the Python buffer protocol. A more
        # realistic example would be a NumPy extension type like the bfloat16
        # type from the ml_dtypes package, which isn't a dependency of this
        # test.
        attr = DenseElementsAttr.get(
            array.view(np.datetime64), type=IntegerType.get_signless(64)
        )
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


################################################################################
# Tests of the list of attributes .get() factory method
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsFromList
@run
def testGetDenseElementsFromList():
    with Context(), Location.unknown():
        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
        attr = DenseElementsAttr.get(attrs)

        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
@run
def testGetDenseElementsFromListWithExplicitType():
    with Context(), Location.unknown():
        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
        shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
        attr = DenseElementsAttr.get(attrs, shaped_type)

        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
@run
def testGetDenseElementsFromListEmptyList():
    with Context(), Location.unknown():
        attrs = []

        try:
            attr = DenseElementsAttr.get(attrs)
        except ValueError as e:
            # CHECK: Attributes list must be non-empty
            print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
@run
def testGetDenseElementsFromListNonAttributeType():
    with Context(), Location.unknown():
        attrs = [1.0]

        try:
            attr = DenseElementsAttr.get(attrs)
        except RuntimeError as e:
            # CHECK: Invalid attribute when attempting to create an ArrayAttribute
            print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
@run
def testGetDenseElementsFromListMismatchedType():
    """f64 attributes cannot populate an explicitly f32-typed tensor."""
    with Context(), Location.unknown():
        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
        shaped_type = ShapedType(Type.parse("tensor<2xf32>"))

        try:
            attr = DenseElementsAttr.get(attrs, shaped_type)
        except ValueError as e:
            # CHECK: All attributes must be of the same type and match the type parameter
            print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
@run
def testGetDenseElementsFromListMixedTypes():
    """A list mixing f64 and f32 attributes is rejected."""
    with Context(), Location.unknown():
        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]

        try:
            attr = DenseElementsAttr.get(attrs)
        except ValueError as e:
            # CHECK: All attributes must be of the same type and match the type parameter
            print(e)


################################################################################
# Splats.
136################################################################################ 137 138# CHECK-LABEL: TEST: testGetDenseElementsSplatInt 139@run 140def testGetDenseElementsSplatInt(): 141 with Context(), Location.unknown(): 142 t = IntegerType.get_signless(32) 143 element = IntegerAttr.get(t, 555) 144 shaped_type = RankedTensorType.get((2, 3, 4), t) 145 attr = DenseElementsAttr.get_splat(shaped_type, element) 146 # CHECK: dense<555> : tensor<2x3x4xi32> 147 print(attr) 148 # CHECK: is_splat: True 149 print("is_splat:", attr.is_splat) 150 151 # CHECK: splat_value: IntegerAttr(555 : i32) 152 splat_value = attr.get_splat_value() 153 print("splat_value:", repr(splat_value)) 154 assert splat_value == element 155 156 157# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat 158@run 159def testGetDenseElementsSplatFloat(): 160 with Context(), Location.unknown(): 161 t = F32Type.get() 162 element = FloatAttr.get(t, 1.2) 163 shaped_type = RankedTensorType.get((2, 3, 4), t) 164 attr = DenseElementsAttr.get_splat(shaped_type, element) 165 # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> 166 print(attr) 167 assert attr.get_splat_value() == element 168 169 170# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors 171@run 172def testGetDenseElementsSplatErrors(): 173 with Context(), Location.unknown(): 174 t = F32Type.get() 175 other_t = F64Type.get() 176 element = FloatAttr.get(t, 1.2) 177 other_element = FloatAttr.get(other_t, 1.2) 178 shaped_type = RankedTensorType.get((2, 3, 4), t) 179 dynamic_shaped_type = UnrankedTensorType.get(t) 180 non_shaped_type = t 181 182 try: 183 attr = DenseElementsAttr.get_splat(non_shaped_type, element) 184 except ValueError as e: 185 # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) 186 print(e) 187 188 try: 189 attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) 190 except ValueError as e: 191 # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) 192 print(e) 193 194 
try: 195 attr = DenseElementsAttr.get_splat(shaped_type, other_element) 196 except ValueError as e: 197 # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) 198 print(e) 199 200 201# CHECK-LABEL: TEST: testRepeatedValuesSplat 202@run 203def testRepeatedValuesSplat(): 204 with Context(): 205 array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32) 206 attr = DenseElementsAttr.get(array) 207 # CHECK: dense<1.000000e+00> : tensor<2x3xf32> 208 print(attr) 209 # CHECK: is_splat: True 210 print("is_splat:", attr.is_splat) 211 # CHECK{LITERAL}: [[1. 1. 1.] 212 # CHECK{LITERAL}: [1. 1. 1.]] 213 print(np.array(attr)) 214 215 216# CHECK-LABEL: TEST: testNonSplat 217@run 218def testNonSplat(): 219 with Context(): 220 array = np.array([2.0, 1.0, 1.0], dtype=np.float32) 221 attr = DenseElementsAttr.get(array) 222 # CHECK: is_splat: False 223 print("is_splat:", attr.is_splat) 224 225 226################################################################################ 227# Tests of the array/buffer .get() factory method, in all of its permutations. 228################################################################################ 229 230### explicitly provided types 231 232 233@run 234def testGetDenseElementsBF16(): 235 with Context(): 236 array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16) 237 attr = DenseElementsAttr.get(array, type=BF16Type.get()) 238 # Note: These values don't mean much since just bit-casting. But they 239 # shouldn't change. 240 # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16> 241 print(attr) 242 243 244@run 245def testGetDenseElementsInteger4(): 246 with Context(): 247 array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8) 248 attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4)) 249 # Note: These values don't mean much since just bit-casting. 
But they 250 # shouldn't change. 251 # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4> 252 print(attr) 253 254 255@run 256def testGetDenseElementsBool(): 257 with Context(): 258 bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_) 259 array = np.packbits(bool_array, axis=None, bitorder="little") 260 attr = DenseElementsAttr.get( 261 array, type=IntegerType.get_signless(1), shape=bool_array.shape 262 ) 263 # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1> 264 print(attr) 265 266 267@run 268def testGetDenseElementsBoolSplat(): 269 with Context(): 270 zero = np.array(0, dtype=np.uint8) 271 one = np.array(255, dtype=np.uint8) 272 print(one) 273 # CHECK: dense<false> : tensor<4x2x5xi1> 274 print( 275 DenseElementsAttr.get( 276 zero, type=IntegerType.get_signless(1), shape=(4, 2, 5) 277 ) 278 ) 279 # CHECK: dense<true> : tensor<4x2x5xi1> 280 print( 281 DenseElementsAttr.get( 282 one, type=IntegerType.get_signless(1), shape=(4, 2, 5) 283 ) 284 ) 285 286 287### float and double arrays. 288 289 290# CHECK-LABEL: TEST: testGetDenseElementsF16 291@run 292def testGetDenseElementsF16(): 293 with Context(): 294 array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16) 295 attr = DenseElementsAttr.get(array) 296 # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16> 297 print(attr) 298 # CHECK: {{\[}}[ 2. 4. 8.] 299 # CHECK: {{\[}}16. 32. 
64.]] 300 print(np.array(attr)) 301 302 303# CHECK-LABEL: TEST: testGetDenseElementsF32 304@run 305def testGetDenseElementsF32(): 306 with Context(): 307 array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) 308 attr = DenseElementsAttr.get(array) 309 # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> 310 print(attr) 311 # CHECK: {{\[}}[1.1 2.2 3.3] 312 # CHECK: {{\[}}4.4 5.5 6.6]] 313 print(np.array(attr)) 314 315 316# CHECK-LABEL: TEST: testGetDenseElementsF64 317@run 318def testGetDenseElementsF64(): 319 with Context(): 320 array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) 321 attr = DenseElementsAttr.get(array) 322 # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> 323 print(attr) 324 # CHECK: {{\[}}[1.1 2.2 3.3] 325 # CHECK: {{\[}}4.4 5.5 6.6]] 326 print(np.array(attr)) 327 328 329### 16 bit integer arrays 330# CHECK-LABEL: TEST: testGetDenseElementsI16Signless 331@run 332def testGetDenseElementsI16Signless(): 333 with Context(): 334 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) 335 attr = DenseElementsAttr.get(array) 336 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> 337 print(attr) 338 # CHECK: {{\[}}[1 2 3] 339 # CHECK: {{\[}}4 5 6]] 340 print(np.array(attr)) 341 342 343# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless 344@run 345def testGetDenseElementsUI16Signless(): 346 with Context(): 347 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) 348 attr = DenseElementsAttr.get(array) 349 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> 350 print(attr) 351 # CHECK: {{\[}}[1 2 3] 352 # CHECK: {{\[}}4 5 6]] 353 print(np.array(attr)) 354 355 356# CHECK-LABEL: TEST: testGetDenseElementsI16 357@run 358def testGetDenseElementsI16(): 359 with Context(): 360 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) 361 attr = 
DenseElementsAttr.get(array, signless=False) 362 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16> 363 print(attr) 364 # CHECK: {{\[}}[1 2 3] 365 # CHECK: {{\[}}4 5 6]] 366 print(np.array(attr)) 367 368 369# CHECK-LABEL: TEST: testGetDenseElementsUI16 370@run 371def testGetDenseElementsUI16(): 372 with Context(): 373 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) 374 attr = DenseElementsAttr.get(array, signless=False) 375 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16> 376 print(attr) 377 # CHECK: {{\[}}[1 2 3] 378 # CHECK: {{\[}}4 5 6]] 379 print(np.array(attr)) 380 381 382### 32 bit integer arrays 383# CHECK-LABEL: TEST: testGetDenseElementsI32Signless 384@run 385def testGetDenseElementsI32Signless(): 386 with Context(): 387 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) 388 attr = DenseElementsAttr.get(array) 389 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 390 print(attr) 391 # CHECK: {{\[}}[1 2 3] 392 # CHECK: {{\[}}4 5 6]] 393 print(np.array(attr)) 394 395 396# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless 397@run 398def testGetDenseElementsUI32Signless(): 399 with Context(): 400 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) 401 attr = DenseElementsAttr.get(array) 402 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 403 print(attr) 404 # CHECK: {{\[}}[1 2 3] 405 # CHECK: {{\[}}4 5 6]] 406 print(np.array(attr)) 407 408 409# CHECK-LABEL: TEST: testGetDenseElementsI32 410@run 411def testGetDenseElementsI32(): 412 with Context(): 413 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) 414 attr = DenseElementsAttr.get(array, signless=False) 415 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> 416 print(attr) 417 # CHECK: {{\[}}[1 2 3] 418 # CHECK: {{\[}}4 5 6]] 419 print(np.array(attr)) 420 421 422# CHECK-LABEL: TEST: testGetDenseElementsUI32 423@run 424def testGetDenseElementsUI32(): 425 with Context(): 426 array = np.array([[1, 2, 3], [4, 5, 6]], 
dtype=np.uint32) 427 attr = DenseElementsAttr.get(array, signless=False) 428 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> 429 print(attr) 430 # CHECK: {{\[}}[1 2 3] 431 # CHECK: {{\[}}4 5 6]] 432 print(np.array(attr)) 433 434 435## 64bit integer arrays 436# CHECK-LABEL: TEST: testGetDenseElementsI64Signless 437@run 438def testGetDenseElementsI64Signless(): 439 with Context(): 440 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) 441 attr = DenseElementsAttr.get(array) 442 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> 443 print(attr) 444 # CHECK: {{\[}}[1 2 3] 445 # CHECK: {{\[}}4 5 6]] 446 print(np.array(attr)) 447 448 449# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless 450@run 451def testGetDenseElementsUI64Signless(): 452 with Context(): 453 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) 454 attr = DenseElementsAttr.get(array) 455 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> 456 print(attr) 457 # CHECK: {{\[}}[1 2 3] 458 # CHECK: {{\[}}4 5 6]] 459 print(np.array(attr)) 460 461 462# CHECK-LABEL: TEST: testGetDenseElementsI64 463@run 464def testGetDenseElementsI64(): 465 with Context(): 466 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) 467 attr = DenseElementsAttr.get(array, signless=False) 468 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> 469 print(attr) 470 # CHECK: {{\[}}[1 2 3] 471 # CHECK: {{\[}}4 5 6]] 472 print(np.array(attr)) 473 474 475# CHECK-LABEL: TEST: testGetDenseElementsUI64 476@run 477def testGetDenseElementsUI64(): 478 with Context(): 479 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) 480 attr = DenseElementsAttr.get(array, signless=False) 481 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> 482 print(attr) 483 # CHECK: {{\[}}[1 2 3] 484 # CHECK: {{\[}}4 5 6]] 485 print(np.array(attr)) 486 487 488# CHECK-LABEL: TEST: testGetDenseElementsIndex 489@run 490def testGetDenseElementsIndex(): 491 with Context(), 
Location.unknown(): 492 idx_type = IndexType.get() 493 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) 494 attr = DenseElementsAttr.get(array, type=idx_type) 495 # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex> 496 print(attr) 497 arr = np.array(attr) 498 # CHECK: {{\[}}[1 2 3] 499 # CHECK: {{\[}}4 5 6]] 500 print(arr) 501 # CHECK: True 502 print(arr.dtype == np.int64) 503 504 505# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr 506@run 507def testGetDenseResourceElementsAttr(): 508 def on_delete(_): 509 print("BACKING MEMORY DELETED") 510 511 context = Context() 512 mview = memoryview(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) 513 ref = weakref.ref(mview, on_delete) 514 515 def test_attribute(context, mview): 516 with context, Location.unknown(): 517 element_type = IntegerType.get_signless(32) 518 tensor_type = RankedTensorType.get((2, 3), element_type) 519 resource = DenseResourceElementsAttr.get_from_buffer( 520 mview, "from_py", tensor_type 521 ) 522 module = Module.parse("module {}") 523 module.operation.attributes["test.resource"] = resource 524 # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32> 525 # CHECK: from_py: "0x04000000010000000200000003000000040000000500000006000000" 526 print(module) 527 528 # Verifies type casting. 529 # CHECK: dense_resource<from_py> : tensor<2x3xi32> 530 print( 531 DenseResourceElementsAttr(module.operation.attributes["test.resource"]) 532 ) 533 534 test_attribute(context, mview) 535 mview = None 536 gc.collect() 537 # CHECK: FREEING CONTEXT 538 print("FREEING CONTEXT") 539 context = None 540 gc.collect() 541 # CHECK: BACKING MEMORY DELETED 542 # CHECK: EXIT FUNCTION 543 print("EXIT FUNCTION") 544