# RUN: %PYTHON %s | FileCheck %s

import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *

# numpy is used below to compute golden strides/offsets. Import it explicitly
# (after the wildcard import so this binding wins) instead of relying on the
# name possibly leaking out of `from mlir.ir import *`.
import numpy as np


def run(f):
    """Print a test banner for ``f``, execute it immediately, and return it.

    Used as a decorator: each decorated test runs at import time and its
    printed output is matched against the directive comments in this file.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
    """Parse a function containing a subview op and exercise its accessors."""
    ctx = Context()
    module = Module.parse(
        r"""
    func.func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      return
    }
    """,
        ctx,
    )
    func_body = module.body.operations[0].regions[0].blocks[0]
    # Operations 0-5 are the six constants; the subview is the seventh op.
    subview = func_body.operations[6]

    assert subview.source == subview.operands[0]
    assert len(subview.offsets) == 2
    assert len(subview.sizes) == 2
    assert len(subview.strides) == 2
    assert subview.result == subview.results[0]

    # CHECK: SubViewOp
    print(type(subview).__name__)

    # CHECK: constant 0
    print(subview.offsets[0])
    # CHECK: constant 1
    print(subview.offsets[1])
    # CHECK: constant 2
    print(subview.sizes[0])
    # CHECK: constant 3
    print(subview.sizes[1])
    # CHECK: constant 4
    print(subview.strides[0])
    # CHECK: constant 5
    print(subview.strides[1])


# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
    """Build a ``memref.LoadOp`` from function arguments and verify the module."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
        return
      }
    """
        )
        f = module.body.operations[0]
        func_body = f.regions[0].blocks[0]
        with InsertionPoint.at_block_terminator(func_body):
            # First argument is the memref, the remaining two are the indices.
            memref.LoadOp(f.arguments[0], f.arguments[1:])

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()


# CHECK-LABEL: TEST: testMemRefAttr
@run
def testMemRefAttr():
    """Create a ``memref.global`` from a Python-built memref type and print it."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            memref.global_("objFifo_in0", T.memref(16, T.i32()))
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    """Exercise result-type inference of ``memref.subview``.

    Covers Python-integer operands, ``arith.constant`` operands, a
    non-constant offset, the two ``ValueError`` paths printed below, and an
    explicitly supplied ``result_type``.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            # The offset below is the result of an addi, not a constant.
            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(z)

            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Compare inferred subview layouts against NumPy slicing of a same-shaped array."""

    def check_strides_offset(subview_value, np_view):
        """Assert the subview's layout strides/offset equal those of ``np_view``.

        NumPy reports strides and data pointers in bytes, so both are divided
        by the element size to obtain element counts before comparing. The
        first parameter is deliberately not called ``memref`` so that it does
        not shadow the imported dialect module.
        """
        layout = subview_value.type.layout
        dtype_size_in_bytes = np_view.dtype.itemsize
        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // dtype_size_in_bytes

        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # default strides and offset means no stridedlayout attribute means affinemap layout
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            # fmt: on

            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on