# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable if not available.

import gc
from mlir.ir import *
import numpy as np


def run(f):
    """Run test function `f` immediately (used as a decorator).

    Prints a "TEST: <name>" banner that the FileCheck labels in this file
    anchor on, calls `f`, forces a garbage collection, and asserts that no
    MLIR Context objects are still alive (per `Context._get_live_count()`).
    Returns `f` unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so Contexts only kept alive by reference cycles are freed
    # before the leak check.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtype.
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
@run
def testGetDenseElementsUnsupported():
    """A NumPy string array has no convertible element format and must raise."""
    with Context():
        array = np.array([["hello", "goodbye"]])
        try:
            DenseElementsAttr.get(array)
        except ValueError as e:
            # CHECK: unimplemented array format conversion from format:
            print(e)
        else:
            # Fail loudly instead of silently printing nothing.
            raise AssertionError(
                "expected DenseElementsAttr.get to raise ValueError"
            )


# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
    """An unsupported buffer format is accepted when an explicit type is given."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        # datetime64 specifically isn't important: it's just a 64-bit type that
        # doesn't have a format under the Python buffer protocol. A more
        # realistic example would be a NumPy extension type like the bfloat16
        # type from the ml_dtypes package, which isn't a dependency of this
        # test.
        attr = DenseElementsAttr.get(
            array.view(np.datetime64), type=IntegerType.get_signless(64)
        )
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


################################################################################
# Splats.
################################################################################

# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
    """Build an i32 splat with get_splat and read back its splat value."""
    with Context(), Location.unknown():
        t = IntegerType.get_signless(32)
        element = IntegerAttr.get(t, 555)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        attr = DenseElementsAttr.get_splat(shaped_type, element)
        # CHECK: dense<555> : tensor<2x3x4xi32>
        print(attr)
        # CHECK: is_splat: True
        print("is_splat:", attr.is_splat)

        # CHECK: splat_value: IntegerAttr(555 : i32)
        splat_value = attr.get_splat_value()
        print("splat_value:", repr(splat_value))
        assert splat_value == element


# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
    """Build an f32 splat with get_splat and compare its splat value."""
    with Context(), Location.unknown():
        t = F32Type.get()
        element = FloatAttr.get(t, 1.2)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        attr = DenseElementsAttr.get_splat(shaped_type, element)
        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
        print(attr)
        assert attr.get_splat_value() == element


# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
    """get_splat rejects non-shaped, unranked and element-mismatched inputs."""
    with Context(), Location.unknown():
        t = F32Type.get()
        other_t = F64Type.get()
        element = FloatAttr.get(t, 1.2)
        other_element = FloatAttr.get(other_t, 1.2)
        shaped_type = RankedTensorType.get((2, 3, 4), t)
        dynamic_shaped_type = UnrankedTensorType.get(t)
        non_shaped_type = t

        # Each case must raise; the else branches make a silent pass fail loudly.
        try:
            DenseElementsAttr.get_splat(non_shaped_type, element)
        except ValueError as e:
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
            print(e)
        else:
            raise AssertionError("expected ValueError for non-shaped type")

        try:
            DenseElementsAttr.get_splat(dynamic_shaped_type, element)
        except ValueError as e:
            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
            print(e)
        else:
            raise AssertionError("expected ValueError for unranked type")

        try:
            DenseElementsAttr.get_splat(shaped_type, other_element)
        except ValueError as e:
            # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
            print(e)
        else:
            raise AssertionError("expected ValueError for element type mismatch")


# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
    """An array whose elements are all equal is stored as a splat."""
    with Context():
        array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
        print(attr)
        # CHECK: is_splat: True
        print("is_splat:", attr.is_splat)
        # CHECK{LITERAL}: [[1. 1. 1.]
        # CHECK{LITERAL}: [1. 1. 1.]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
    """An array with differing elements is not a splat."""
    with Context():
        array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: is_splat: False
        print("is_splat:", attr.is_splat)


################################################################################
# Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################

### explicitly provided types


# CHECK-LABEL: TEST: testGetDenseElementsBF16
@run
def testGetDenseElementsBF16():
    """An explicit bf16 type bit-casts a uint16 buffer into bf16 elements."""
    with Context():
        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, type=BF16Type.get())
        # Note: These values don't mean much since just bit-casting. But they
        # shouldn't change.
        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsInteger4
@run
def testGetDenseElementsInteger4():
    """An explicit i4 type bit-casts a one-byte-per-element buffer."""
    with Context():
        # int8 rather than uint8: NumPy >= 2 raises OverflowError when negative
        # Python ints are stored into an unsigned dtype. The two's-complement
        # bytes are identical either way, so the attribute is unchanged.
        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
        # Note: These values don't mean much since just bit-casting. But they
        # shouldn't change.
        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsBool
@run
def testGetDenseElementsBool():
    """i1 elements are loaded from a bit-packed uint8 buffer plus a shape."""
    with Context():
        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
        # Pack eight booleans per byte (bitorder="little"); the logical shape
        # is lost by packing, so it is passed explicitly.
        array = np.packbits(bool_array, axis=None, bitorder="little")
        attr = DenseElementsAttr.get(
            array, type=IntegerType.get_signless(1), shape=bool_array.shape
        )
        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
        print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsBoolSplat
@run
def testGetDenseElementsBoolSplat():
    """A single 0x00 / 0xFF byte with an explicit shape yields an i1 splat."""
    with Context():
        zero = np.array(0, dtype=np.uint8)
        one = np.array(255, dtype=np.uint8)
        # CHECK: dense<false> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )
        # CHECK: dense<true> : tensor<4x2x5xi1>
        print(
            DenseElementsAttr.get(
                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
            )
        )


### float and double arrays.

# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
    """A float16 array maps to an f16 tensor and round-trips through NumPy."""
    with Context():
        array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
        print(attr)
        # CHECK: {{\[}}[ 2. 4. 8.]
        # CHECK: {{\[}}16. 32. 64.]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
    """A float32 array maps to an f32 tensor and round-trips through NumPy."""
    with Context():
        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
        print(attr)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
    """A float64 array maps to an f64 tensor and round-trips through NumPy."""
    with Context():
        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
        print(attr)
        # CHECK: {{\[}}[1.1 2.2 3.3]
        # CHECK: {{\[}}4.4 5.5 6.6]]
        print(np.array(attr))


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
    """An int16 array maps to a signless i16 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
    """A uint16 array also maps to a signless i16 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
    """With signless=False an int16 array maps to an si16 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
    """With signless=False a uint16 array maps to a ui16 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
    """An int32 array maps to a signless i32 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
    """A uint32 array also maps to a signless i32 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
    """With signless=False an int32 array maps to an si32 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
    """With signless=False a uint32 array maps to a ui32 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


### 64 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
    """An int64 array maps to a signless i64 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
    """A uint64 array also maps to a signless i64 tensor by default."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
        attr = DenseElementsAttr.get(array)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
    """With signless=False an int64 array maps to an si64 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
    """With signless=False a uint64 array maps to a ui64 tensor."""
    with Context():
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
        attr = DenseElementsAttr.get(array, signless=False)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
        print(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(np.array(attr))


# CHECK-LABEL: TEST: testGetDenseElementsIndex
@run
def testGetDenseElementsIndex():
    """int64 data with an explicit index type converts back to an int64 array."""
    with Context(), Location.unknown():
        idx_type = IndexType.get()
        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
        attr = DenseElementsAttr.get(array, type=idx_type)
        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
        print(attr)
        arr = np.array(attr)
        # CHECK: {{\[}}[1 2 3]
        # CHECK: {{\[}}4 5 6]]
        print(arr)
        # CHECK: True
        print(arr.dtype == np.int64)