# RUN: %PYTHON %s | FileCheck %s

import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *


def run(f):
    """Print a banner naming ``f``, invoke it immediately, and hand it back.

    Applied as a decorator, so every test in this file executes at import
    time, in source order, each preceded by its own banner line.
    """
    print(f"\nTEST: {f.__name__}")
    f()
    return f


# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
    """Exercise the named operand/result accessors of a parsed memref.subview."""
    ctx = Context()
    module = Module.parse(
        r"""
  func.func @f1(%arg0: memref<?x?xf32>) {
    %0 = arith.constant 0 : index
    %1 = arith.constant 1 : index
    %2 = arith.constant 2 : index
    %3 = arith.constant 3 : index
    %4 = arith.constant 4 : index
    %5 = arith.constant 5 : index
    memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
    return
  }
""",
        ctx,
    )
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # The six index constants occupy positions 0-5, so the subview is at 6.
    subview_op = entry_block.operations[6]

    assert subview_op.source == subview_op.operands[0]
    for operand_group in (subview_op.offsets, subview_op.sizes, subview_op.strides):
        assert len(operand_group) == 2
    assert subview_op.result == subview_op.results[0]

    # CHECK: SubViewOp
    print(type(subview_op).__name__)

    # Offsets, then sizes, then strides: each group prints its two operands,
    # which are the constants 0..5 in that order.
    # CHECK: constant 0
    # CHECK: constant 1
    # CHECK: constant 2
    # CHECK: constant 3
    # CHECK: constant 4
    # CHECK: constant 5
    for operand_group in (subview_op.offsets, subview_op.sizes, subview_op.strides):
        print(operand_group[0])
        print(operand_group[1])


# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
    """Build a memref.load with the Python op builder and verify the module."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
    func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
      return
    }
  """
        )
        func_op = module.body.operations[0]
        entry_block = func_op.regions[0].blocks[0]
        func_args = func_op.arguments
        # Place the load just before the existing `return` terminator; the
        # first argument is the memref, the remaining two are its indices.
        with InsertionPoint.at_block_terminator(entry_block):
            memref.LoadOp(func_args[0], func_args[1:])

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()
# CHECK-LABEL: TEST: testMemRefAttr
@run
def testMemRefAttr():
    """Create a memref.global through the Python builder and print it."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            memref.global_("objFifo_in0", T.memref(16, T.i32()))
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    """Check subview result-type inference for static, constant and dynamic operands.

    Also checks that inference is refused (ValueError) when it cannot be done,
    and that an explicit ``result_type`` is honoured.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            # An arith.constant offset is folded into the static form.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            # A non-constant offset (result of arith.addi) stays dynamic, so
            # the inferred layout offset is `?`.
            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(z)

            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)
            else:
                # Fail here rather than letting the missing message surface
                # later as an unrelated FileCheck mismatch.
                raise AssertionError(
                    "expected ValueError from _infer_memref_subview_result_type"
                )

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)
            else:
                raise AssertionError(
                    "expected ValueError from memref.subview without result_type"
                )

            # The first argument of StridedLayoutAttr.get is the *offset*, so
            # use the stride/offset dynamic sentinel rather than the size one.
            layout = StridedLayoutAttr.get(
                ShapedType.get_dynamic_stride_or_offset(), [10, 1]
            )
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Compare inferred subview strides/offsets against equivalent numpy views."""
    # Imported explicitly: the file otherwise only gets `np` if
    # `from mlir.ir import *` happens to re-export it.
    import numpy as np

    def check_strides_offset(subview_value, np_view):
        """Assert the subview's strided layout matches ``np_view``'s geometry.

        numpy reports strides and pointer offsets in bytes, so both are
        divided by the element size to get element counts like the memref
        layout. ``np_view.base`` is the array the view was sliced from.
        """
        layout = subview_value.type.layout
        dtype_size_in_bytes = np_view.dtype.itemsize
        golden_strides = [stride // dtype_size_in_bytes for stride in np_view.strides]
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // dtype_size_in_bytes

        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # A full-extent, unit-stride, zero-offset subview gets no strided
            # layout attribute; its layout is the identity affine map instead.
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            # fmt: on

            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on