// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
  %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
  return %0 : memref<5xi8>
}

// -----

// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
  %0 = memref.expand_shape %arg0 [[0], [1]] output_shape [5, 4] : memref<5x4xi8> into memref<5x4xi8>
  return %0 : memref<5x4xi8>
}

// -----

// CHECK-LABEL: collapse_expand_rank0_cancel
// CHECK-NEXT: return
func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
  %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<i8> into memref<1x1xi8>
  return %1 : memref<1x1xi8>
}

// -----

// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
// CHECK: return %[[S]] : memref<16x32xi8, strided{{.*}}>
func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
    memref<16x32xi8, strided<[32, 1], offset: 512>> {
  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
  %1 = memref.subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
    memref<?x?x16x32xi8> to
    memref<16x32xi8, strided<[32, 1], offset: 512>>
  return %1 : memref<16x32xi8, strided<[32, 1], offset: 512>>
}

// -----

// CHECK: func @subview_of_strides_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, strided{{.*}}>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
// CHECK-SAME: to memref<1x4xf32, strided<[7, 1], offset: ?>>
// CHECK: %[[M:.+]] = memref.cast %[[S]]
// CHECK-SAME: to memref<1x4xf32, strided<[?, ?], offset: ?>>
// CHECK: return %[[M]]
func.func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, strided<[35, 7, 1], offset: ?>>) -> memref<1x4xf32, strided<[?, ?], offset: ?>> {
  %0 = memref.cast %arg : memref<1x1x?xf32, strided<[35, 7, 1], offset: ?>> to memref<1x1x?xf32, strided<[?, ?, ?], offset: ?>>
  %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x4xf32, strided<[?, ?], offset: ?>>
  return %1 : memref<1x4xf32, strided<[?, ?], offset: ?>>
}

// -----

// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: memref.subview
// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
  return %0 : memref<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @negative_subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK-SAME: %[[IDX:.+]]: index
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
  %0 = memref.subview %arg0[%idx, 0] [16, 4] [1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
  return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}

// -----

func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return %0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}
// CHECK-LABEL: func @subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32> to memref<4x1x?xf32
// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
// CHECK: return %[[RESULT]]

// -----

func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32> to memref<4x?xf32
// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
// CHECK: return %[[RESULT]]

// -----

func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}
// CHECK: func @multiple_reducing_dims
// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
// CHECK-SAME: : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>

// -----

func.func @multiple_reducing_dims_dynamic(%arg0 : memref<?x?x?xf32>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}
// CHECK: func @multiple_reducing_dims_dynamic
// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32> to memref<1x?xf32, strided<[?, 1], offset: ?>>
// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
// CHECK-SAME: : memref<1x?xf32, strided<[?, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>

// -----

func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1]
      : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
  return %1 : memref<?xf32, strided<[?], offset: ?>>
}
// CHECK: func @multiple_reducing_dims_all_dynamic
// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
// CHECK-SAME: : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>

// -----

func.func @subview_negative_stride1(%arg0 : memref<?xf32>) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant -1 : index
  %1 = memref.dim %arg0, %c0 : memref<?xf32>
  %2 = arith.addi %1, %c1 : index
  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %3 : memref<?xf32, strided<[?], offset: ?>>
}
// CHECK: func @subview_negative_stride1
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
// CHECK: %[[C1:.*]] = arith.constant 0
// CHECK: %[[C2:.*]] = arith.constant -1
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<?xf32>
// CHECK: %[[DIM2:.*]] = arith.addi %[[DIM1]], %[[C2]] : index
// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][%[[DIM2]]] [%[[DIM1]]] [-1] : memref<?xf32> to memref<?xf32, strided<[-1], offset: ?>>
// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<?xf32, strided<[-1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>

// -----

func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant -1 : index
  %1 = memref.dim %arg0, %c0 : memref<7xf32>
  %2 = arith.addi %1, %c1 : index
  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %3 : memref<?xf32, strided<[?], offset: ?>>
}
// CHECK: func @subview_negative_stride2
// CHECK-SAME: (%[[ARG0:.*]]: memref<7xf32>)
// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<7xf32, strided<[-1], offset: 6>> to memref<?xf32, strided<[?], offset: ?>>
// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>

// -----

// CHECK-LABEL: func @dim_of_sized_view
// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref<?xi8>
// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index
// CHECK: return %[[SIZE]] : index
func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
  %c0 = arith.constant 0 : index
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.dim %0, %c0 : memref<?xi8>
  return %1 : index
}

// -----

// CHECK-LABEL: func @no_fold_subview_negative_size
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?x256xf32, strided<[1024, 1], offset: 2304>> {
  %cst = arith.constant -13 : index
  %0 = memref.subview %input[2, 256] [%cst, 256] [1, 1] : memref<4x1024xf32> to memref<?x256xf32, strided<[1024, 1], offset: 2304>>
  return %0 : memref<?x256xf32, strided<[1024, 1], offset: 2304>>
}

// -----

// CHECK-LABEL: func @no_fold_subview_zero_stride
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
  return %1 : memref<1xf32, strided<[?], offset: 1>>
}

// -----

// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: memref.store %[[cst]]
func.func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
  memref.store %0, %holder[] : memref<memref<?xi8>>
  return
}

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
// CHECK-NEXT: return %[[SIZE]] : index
func.func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = arith.constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
// CHECK-NEXT: %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32>
// CHECK-NEXT: return %[[RANK]] : index
func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = memref.rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = arith.constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_for(%arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = memref.dim %0, %arg2 : memref<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: memref.alloc() : memref<4xf32>
  %c4 = arith.constant 4 : index
  %a = memref.alloc(%c4) : memref<?xf32>

  // CHECK-NEXT: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %{{.*}} : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_alignment_const_fold
func.func @alloc_alignment_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
  %c4 = arith.constant 4 : index
  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>

  // CHECK-NEXT: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %{{.*}} : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, strided{{.*}}>
// CHECK: return %[[mem1]] : memref<?xi32, strided{{.*}}>
func.func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, strided<[?], offset: ?>> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, strided<[?], offset: ?>>
  return %0 : memref<?xi32, strided<[?], offset: ?>>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, strided{{.*}}>
// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, strided{{.*}}> to memref<?xi32, strided{{.*}}>
// CHECK: return %[[mem2]] : memref<?xi32, strided{{.*}}>
func.func @alloc_const_fold_with_symbols2() -> memref<?xi32, strided<[?], offset: ?>> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, strided<[?], offset: ?>>
  return %0 : memref<?xi32, strided<[?], offset: ?>>
}

// -----

// CHECK-LABEL: func @allocator
// CHECK: %[[alloc:.+]] = memref.alloc
// CHECK: memref.store %[[alloc:.+]], %arg0
func.func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
  %0 = memref.alloc(%arg1) : memref<?xi32>
  memref.store %0, %arg0[] : memref<memref<?xi32>>
  return
}

// -----

func.func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
    -> memref<f32> {
  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
      : memref<1x1x1xf32> into memref<1xf32>
  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
  return %1 : memref<f32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
// CHECK: memref.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>

// -----

func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
    -> memref<?x?xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse
// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.collapse_shape

// -----

func.func @do_not_compose_collapse_of_expand_non_identity_layout(
    %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
    -> memref<?xf32, strided<[?], offset: 0>> {
  %1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] :
      memref<?x?xf32, strided<[?, 1], offset: 0>> into
      memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
      memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>> into
      memref<?xf32, strided<[?], offset: 0>>
  return %2 : memref<?xf32, strided<[?], offset: 0>>
}
// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
// CHECK: expand
// CHECK: collapse

// -----

func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index, %sz2: index, %sz3: index)
    -> memref<?x6x4x5x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%sz2, 6, 4, 5, %sz3] : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
  return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%{{.*}}, 6, 4, 5, %{{.*}}]
// CHECK-NOT: memref.expand_shape

// -----

func.func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
    -> memref<1x1x1xf32> {
  %0 = memref.expand_shape %arg0 [] output_shape [1] : memref<f32> into memref<1xf32>
  %1 = memref.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
      : memref<1xf32> into memref<1x1x1xf32>
  return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1, 1]
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>

// -----

func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
      : memref<12x4xf32> into memref<3x4x4xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x4x4xf32> into memref<12x4xf32>
  return %1 : memref<12x4xf32>
}
// CHECK-LABEL: func @fold_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}_shape

// -----

func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index)
    -> memref<?x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x4x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}_shape

// -----

func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
  %0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 4, 4]
      : memref<8x4xf32> into memref<2x4x4xf32>
  return %1 : memref<2x4x4xf32>
}

// CHECK-LABEL: @fold_memref_expand_cast
// CHECK: memref.expand_shape

// -----

// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME: {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x1xf32> into memref<?x512xf32>
// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
// CHECK-SAME: memref<?x512xf32> to memref<?x?xf32>
// CHECK: return %[[DYNAMIC]] : memref<?x?xf32>
// CHECK: }
func.func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> memref<?x?xf32> {
  %dynamic = memref.cast %arg0 : memref<?x512x1x1xf32> to memref<?x?x?x?xf32>
  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
  return %collapsed : memref<?x?xf32>
}

// -----

// CHECK-LABEL: func @collapse_after_memref_cast(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME: {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x?xf32> into memref<?x?xf32>
// CHECK: return %[[COLLAPSED]] : memref<?x?xf32>
func.func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf32> {
  %dynamic = memref.cast %arg0 : memref<?x512x1x?xf32> to memref<?x?x?x?xf32>
  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
  return %collapsed : memref<?x?xf32>
}

// -----

// CHECK-LABEL: func @collapse_after_memref_cast_type_change_dynamic(
// CHECK-SAME: %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME: {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64>
// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
// CHECK-SAME: memref<1x?xi64> to memref<?x?xi64>
// CHECK: return %[[DYNAMIC]] : memref<?x?xi64>
func.func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
  %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64>
  %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64>
  return %collapsed : memref<?x?xi64>
}

// -----

func.func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 : index)
    -> memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c4 = arith.constant 4 : index
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
      : memref<2x5x7x1xf32> to memref<?x?x?xf32, strided<[35, 7, 1], offset: ?>>
  %1 = memref.cast %0
      : memref<?x?x?xf32, strided<[35, 7, 1], offset: ?>> to
        memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>>
  return %1 : memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>>
}

// CHECK-LABEL: func @reduced_memref
// CHECK: %[[RESULT:.+]] = memref.subview
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, strided{{.+}}>
// CHECK: return %[[RESULT]]

// -----

// CHECK-LABEL: func @fold_rank_memref
func.func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
  // Fold a rank into a constant
  // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
  %rank_0 = memref.rank %arg0 : memref<?x?xf32>

  // CHECK-NEXT: return [[C2]]
  return %rank_0 : index
}

// -----

func.func @fold_no_op_subview(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1]>> {
  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1]>>
  return %0 : memref<20x42xf32, strided<[42, 1]>>
}
// CHECK-LABEL: func @fold_no_op_subview(
// CHECK: %[[ARG0:.+]]: memref<20x42xf32>)
// CHECK: %[[CAST:.+]] = memref.cast %[[ARG0]]
// CHECK: return %[[CAST]]

// -----

func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1], offset: 1>> {
  %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1], offset: 1>>
  return %0 : memref<20x42xf32, strided<[42, 1], offset: 1>>
}
// CHECK-LABEL: func @no_fold_subview_with_non_zero_offset(
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]

// -----

func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 2]>> {
  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 2]>>
  return %0 : memref<20x42xf32, strided<[42, 2]>>
}
// CHECK-LABEL: func @no_fold_subview_with_non_unit_stride(
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]

// -----

func.func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32, strided<[?, 1]>> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
  %1 = memref.dim %arg0, %c1 : memref<?x?xf32>
  %2 = memref.subview %arg0[0, 0] [%0, %1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
  return %2 : memref<?x?xf32, strided<[?, 1]>>
}
// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]

// -----

func.func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
  %v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
  %a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
  return
}

// CHECK-LABEL: func @atomicrmw_cast_fold
// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32

// -----

func.func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, strided<[?], offset: ?>>
  memref.copy %casted1, %casted2 : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
  return
}

// CHECK-LABEL: func @copy_of_cast(
// CHECK-SAME: %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
// CHECK: %[[casted2:.*]] = memref.cast %[[m2]]
// CHECK: memref.copy %[[m1]], %[[casted2]]

// -----

func.func @self_copy(%m1: memref<?xf32>) {
  memref.copy %m1, %m1 : memref<?xf32> to memref<?xf32>
  return
}

// CHECK-LABEL: func @self_copy
// CHECK-NEXT: return

// -----

// CHECK-LABEL: func @empty_copy
// CHECK-NEXT: return
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
  return
}

// -----

func.func @scopeMerge() {
  memref.alloca_scope {
    %cnt = "test.count"() : () -> index
    %a = memref.alloca(%cnt) : memref<?xi64>
    "test.use"(%a) : (memref<?xi64>) -> ()
  }
  return
}
// CHECK: func @scopeMerge() {
// CHECK-NOT: alloca_scope
// CHECK: %[[cnt:.+]] = "test.count"() : () -> index
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK: return

func.func @scopeMerge2() {
  "test.region"() ({
    memref.alloca_scope {
      %cnt = "test.count"() : () -> index
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK: func @scopeMerge2() {
// CHECK: "test.region"() ({
// CHECK: memref.alloca_scope {
// CHECK: %[[cnt:.+]] = "test.count"() : () -> index
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK: }
// CHECK: "test.terminator"() : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
// CHECK: }

func.func @scopeMerge3() {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK: func @scopeMerge3() {
// CHECK: %[[cnt:.+]] = "test.count"() : () -> index
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK: "test.region"() ({
// CHECK: memref.alloca_scope {
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK: }
// CHECK: "test.terminator"() : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
// CHECK: }

func.func @scopeMerge4() {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.op"() : () -> ()
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK: func @scopeMerge4() {
// CHECK: %[[cnt:.+]] = "test.count"() : () -> index
// CHECK: "test.region"() ({
// CHECK: memref.alloca_scope {
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK: }
// CHECK: "test.op"() : () -> ()
// CHECK: "test.terminator"() : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
// CHECK: }

func.func @scopeMerge5() {
  "test.region"() ({
    memref.alloca_scope {
      affine.parallel (%arg) = (0) to (64) {
        %a = memref.alloca(%arg) : memref<?xi64>
        "test.use"(%a) : (memref<?xi64>) -> ()
      }
    }
    "test.op"() : () -> ()
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK: func @scopeMerge5() {
// CHECK: "test.region"() ({
// CHECK: affine.parallel (%[[cnt:.+]]) = (0) to (64) {
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK: }
// CHECK: "test.op"() : () -> ()
// CHECK: "test.terminator"() : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
// CHECK: }

func.func @scopeInline(%arg : memref<index>) {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      memref.store %cnt, %arg[] : memref<index>
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK: func @scopeInline
// CHECK-NOT: memref.alloca_scope

// -----

// CHECK-LABEL: func @reinterpret_noop
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
// CHECK-NEXT: return %[[ARG]]
func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] : memref<2x3x4xf32> to memref<2x3x4xf32>
  return %0 : memref<2x3x4xf32>
}

// -----

// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
// CHECK: return %[[RES]]
func.func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// CHECK-LABEL: func @reinterpret_of_cast
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
// CHECK: return %[[RES]]
func.func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
  %0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// CHECK-LABEL: func @reinterpret_of_subview
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
// CHECK: return %[[RES]]
func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
  %0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// Check that a reinterpret cast of an equivalent extract strided metadata
// is canonicalized to a plain cast when the destination type is different
// from the type of the original memref.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mismatch
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_type_mismatch(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Similar to reinterpret_of_extract_strided_metadata_w_type_mismatch except that
// we check that the match happens when the static information has been folded.
// E.g., in this case, we know that the size of dim 0 is 8 and the size of dim 1
// is 2. So even if we don't use the values %sizes#0 and %sizes#1, as long as
// they have the same constant values, the match is valid.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %c8 = arith.constant 8 : index
  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Check that a reinterpret cast of an equivalent extract strided metadata
// is completely removed when the original memref has the same type.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32
// CHECK: return %[[ARG]]
func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Check that we don't simplify reinterpret cast of extract strided metadata
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}

// -----

// Check that we don't simplify reinterpret cast of extract strided metadata
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
    %arg1 : index) -> memref<?xf32, strided<[?], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[%c0, %c0] [1, %arg1] [%c1, %c1] : memref<8x?xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %0 : memref<?xf32, strided<[?], offset: ?>>
}
// CHECK: func @canonicalize_rank_reduced_subview
// CHECK-SAME: %[[ARG0:.+]]: memref<8x?xf32>
// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1]
// CHECK-SAME: memref<8x?xf32> to memref<?xf32, strided<[1]>>

// -----

// CHECK-LABEL: func @memref_realloc_dead
// CHECK-SAME: %[[SRC:[0-9a-z]+]]: memref<2xf32>
// CHECK-NOT: memref.realloc
// CHECK: return %[[SRC]]
func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32> {
  %0 = memref.realloc %src : memref<2xf32> to memref<4xf32>
  %i2 = arith.constant 2 : index
  memref.store %v, %0[%i2] : memref<4xf32>
  return %src : memref<2xf32>
}

// -----

// CHECK-LABEL: func @collapse_expand_fold_to_cast(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
// CHECK: return %[[casted]]
func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0: index)
    -> (memref<?xf32, 3>)
{
  %0 = memref.expand_shape %m [[0, 1]] output_shape [1, %sz0]
      : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
  %1 = memref.collapse_shape %0 [[0, 1]]
      : memref<1x?xf32, 3> into memref<?xf32, 3>
  return %1 : memref<?xf32, 3>
}

// -----

// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
// CHECK: return %[[subview]]
func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
    %sz: index)
    -> memref<?xf32, strided<[?], offset: ?>>
{
  %0 = memref.subview %m[5] [%sz] [1]
      : memref<?xf32, strided<[?], offset: ?>>
        to memref<?xf32, strided<[?], offset: ?>>
  %1 = memref.subview %0[0] [%sz] [1]
      : memref<?xf32, strided<[?], offset: ?>>
        to memref<?xf32, strided<[?], offset: ?>>
  return %1 : memref<?xf32, strided<[?], offset: ?>>
}

// -----

// CHECK-LABEL: func @load_store_nontemporal(
func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
  %1 = arith.constant 7 : index
  // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
  %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
  // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
  memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
  func.return
}

// -----

// CHECK-LABEL: func @fold_trivial_memory_space_cast(
// CHECK-SAME: %[[arg:.*]]: memref<?xf32>
// CHECK: return %[[arg]]
func.func @fold_trivial_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32> {
  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32>
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: func @fold_multiple_memory_space_cast(
// CHECK-SAME: %[[arg:.*]]: memref<?xf32>
// CHECK: %[[res:.*]] = memref.memory_space_cast %[[arg]] : memref<?xf32> to memref<?xf32, 2>
// CHECK: return %[[res]]
func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32, 2> {
  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
  %1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
  return %1 : memref<?xf32, 2>
}

// -----

// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
  %idx1 = index.constant 1
  %c-2 = arith.constant -2 : index
  %c15 = arith.constant 15 : index
  // CHECK: %[[ALLOC:.*]] = memref.alloc(%c-2) : memref<15x?x1xi1>
  %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
  return %alloc : memref<?x?x?xi1>
}

// -----

// CHECK-LABEL: func @subview_rank_reduction(
// CHECK-SAME: %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
  %c1 = arith.constant 1 : index
  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: return %[[cast]]
  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}

// -----

// CHECK-LABEL: func @fold_double_transpose(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
  // CHECK: return %[[ONETRANSPOSE]]
  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_double_transpose2(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
  %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
  // CHECK: return %[[ONETRANSPOSE]]
  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_identity_transpose(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
  // CHECK: return %[[arg0]]
  return %1 : memref<1x2x3x4x5xf32>
}

// -----

#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>

// CHECK-LABEL: func @cannot_fold_transpose_cast(
// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32>
func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32>) -> memref<?x?xf32, #transpose_map> {
  // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32> to memref<?x?xf32>
  %cast = memref.cast %arg0 : memref<?x4xf32> to memref<?x?xf32>
  // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #{{.*}}>
  %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #transpose_map>
  // CHECK: return %[[TRANSPOSE]]
  return %transpose : memref<?x?xf32, #transpose_map>
}
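
// -----

// Supplemental sketch, not taken from the suite above: by analogy with the
// @dim_of_alloca test earlier in this file, this assumes memref.dim of a
// memref.alloc likewise folds to the corresponding dynamic-size operand, and
// that the then-unused alloc is erased as dead by canonicalization.
// CHECK-LABEL: func @dim_of_alloc(
// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
// CHECK-NEXT: return %[[SIZE]] : index
func.func @dim_of_alloc(%size: index) -> index {
  // The only dynamic dimension of the alloc is %size, so the dim below
  // should fold to %size and leave the alloc without uses.
  %0 = memref.alloc(%size) : memref<?xf32>
  %c0 = arith.constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xf32>
  return %1 : index
}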