// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B

// TODO: Align naming and format with e.g. vector-transfer-permutation-lowering.mlir

///----------------------------------------------------------------------------------------
/// vector.transfer_read
/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
///
/// NOTE: Scalable vectors are not supported
///----------------------------------------------------------------------------------------

func.func @transfer_read_dims_match_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC4D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK: return %[[VEC4D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-128B: memref.collapse_shape

func.func @transfer_read_dims_match_contiguous_scalable(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x[2]xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x[2]xi8>
  return %res : vector<5x4x3x[2]xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous_scalable
// CHECK-NOT: memref.collapse_shape

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_scalable
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_read_dims_match_contiguous_empty_stride(
    %mem : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC4D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK: return %[[VEC4D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
// CHECK-128B: memref.collapse_shape
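
// For reference, the rewrite verified by the CHECK lines above has roughly
// this shape (an illustrative sketch only - the SSA names are made up and
// nothing below is CHECKed):
//
//   %collapsed = memref.collapse_shape %mem [[0, 1, 2, 3]]
//       : memref<5x4x3x2xi8> into memref<120xi8>
//   %read1d = vector.transfer_read %collapsed[%c0], %cst {in_bounds = [true]}
//       : memref<120xi8>, vector<120xi8>
//   %res = vector.shape_cast %read1d : vector<120xi8> to vector<5x4x3x2xi8>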

// -----

// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable".

func.func @transfer_read_dims_mismatch_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
  return %res : vector<1x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_read_dims_mismatch_non_zero_indices(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x43x4x6xi32>) -> vector<1x2x6xi32> {

  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
    in_bounds = [true, true, true]
  } : memref<1x43x4x6xi32>, vector<1x2x6xi32>
  return %res : vector<1x2x6xi32>
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
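
// The linearized index computed by the affine map above follows directly from
// the row-major strides of the collapsed group [1, 2, 3] of
// memref<1x43x4x6xi32>, i.e. [4 * 6, 6, 1] = [24, 6, 1]:
//
//   (%c0, %idx_1, %idx_2, %c0)  ->  (%c0, %idx_1 * 24 + %idx_2 * 6)
//
// and the vector<1x2x6xi32> becomes a 1-D read of 1 * 2 * 6 = 12 elements.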

// -----

// Overall, the source memref is non-contiguous. However, the slice from which
// the output vector is to be read _is_ contiguous. Hence the flattening works fine.

func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) -> vector<2x2xf32> {

  %c0 = arith.constant 0 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %cst_1 {
    in_bounds = [true, true]
  } : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
  return %res : vector<2x2xf32>
}

// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
// CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape

// -----

// The leading dynamic dims don't affect whether this example is flattenable
// or not. Indeed, those dynamic dims are not candidates for flattening anyway.

func.func @transfer_read_leading_dynamic_dims(
    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) -> vector<8x4xi8> {

  %c0_i8 = arith.constant 0 : i8
  %c0 = arith.constant 0 : index
  %res = vector.transfer_read %mem[%idx_1, %idx_2, %c0, %c0], %c0_i8 {
    in_bounds = [true, true]
  } : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
  return %res : vector<8x4xi8>
}

// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[RES]] : vector<8x4xi8>

// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-128B: memref.collapse_shape
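
// For reference, only the trailing static group is collapsed here (sketch
// only, not CHECKed); the two dynamic dims are left untouched, which is why
// flattening remains legal:
//
//   memref.collapse_shape %mem [[0], [1], [2, 3]]
//       : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
//         into memref<?x?x32xi8, strided<[?, 32, 1], offset: ?>>
//
// with the vector<8x4xi8> flattened to vector<32xi8> (8 * 4 = 32 elements).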

// -----

// One of the dims to be flattened is dynamic - not supported ATM.

func.func @negative_transfer_read_dynamic_dim_to_flatten(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {

  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
    in_bounds = [true, true, true]
  } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
  return %res : vector<1x2x6xi32>
}

// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape

// -----

// The vector to be read represents a _non-contiguous_ slice of the input
// memref.

func.func @transfer_read_dims_mismatch_non_contiguous_slice(
    %mem : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
  return %res : vector<2x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-128B-NOT: memref.collapse_shape
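
// Worked example for the test above. memref<5x4x3x2xi8> has row-major strides
// [24, 6, 2, 1], so vector<2x1x2x2xi8> read at [0, 0, 0, 0] touches linear
// offsets 0..3 (dim-0 index 0) and 24..27 (dim-0 index 1). The gap at offsets
// 4..23 is what makes the slice non-contiguous and hence non-flattenable.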

// -----

func.func @transfer_read_0d(
    %mem : memref<i8>) -> vector<i8> {

  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[], %cst : memref<i8>, vector<i8>
  return %res : vector<i8>
}

// CHECK-LABEL: func.func @transfer_read_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_0d(
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

// The strides make the source memref non-contiguous, hence non-flattenable.

func.func @transfer_read_non_contiguous_src(
    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
///
/// NOTE: Scalable vectors are not supported
///----------------------------------------------------------------------------------------

func.func @transfer_write_dims_match_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
// CHECK-128B: memref.collapse_shape

func.func @transfer_write_dims_match_contiguous_scalable(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x[2]xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<5x4x3x[2]xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous_scalable(
// CHECK-NOT: memref.collapse_shape

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_scalable
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_write_dims_match_contiguous_empty_stride(
    %mem : memref<5x4x3x2xi8>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
// CHECK-128B: memref.collapse_shape
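
// For reference, the write rewrite mirrors the read case (an illustrative
// sketch only - the SSA names are made up and nothing below is CHECKed):
//
//   %collapsed = memref.collapse_shape %mem [[0, 1, 2, 3]]
//       : memref<5x4x3x2xi8> into memref<120xi8>
//   %vec1d = vector.shape_cast %vec : vector<5x4x3x2xi8> to vector<120xi8>
//   vector.transfer_write %vec1d, %collapsed[%c0] {in_bounds = [true]}
//       : vector<120xi8>, memref<120xi8>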

// -----

func.func @transfer_write_dims_mismatch_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<1x1x2x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_write_dims_mismatch_non_zero_indices(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x43x4x6xi32>,
    %vec: vector<1x2x6xi32>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x43x4x6xi32>
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
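
// Note that the two RUN lines diverge on the test above: the plain CHECK
// prefix expects the flattened form, whereas CHECK-128B-NOT expects no
// flattening. The rule - inferred from the tests in this file rather than
// from the pattern's implementation - appears to be that, with
// target-vector-bitwidth=128, flattening is only applied when the trailing
// vector dim is narrower than 128 bits; here it is already 6 x 32 = 192 bits
// wide (compare e.g. @transfer_write_dims_mismatch_contiguous, whose trailing
// dim is only 2 x 8 = 16 bits and which does get flattened).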

// -----

// Overall, the destination memref is non-contiguous. However, the slice to
// which the input vector is to be written _is_ contiguous. Hence the
// flattening works fine.

func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
    %vec : vector<2x2xf32>,
    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} :
    vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
  return
}

// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-DAG: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape

// -----

// The leading dynamic dims don't affect whether this example is flattenable
// or not. Indeed, those dynamic dims are not candidates for flattening anyway.

func.func @transfer_write_leading_dynamic_dims(
    %vec : vector<8x4xi8>,
    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
    vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>

// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-128B: memref.collapse_shape

// -----

// One of the dims to be flattened is dynamic - not supported ATM.

func.func @negative_transfer_write_dynamic_dim_to_flatten(
    %idx_1: index,
    %idx_2: index,
    %vec : vector<1x2x6xi32>,
    %mem: memref<1x?x4x6xi32>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x?x4x6xi32>
  return
}

// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
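
// Flattening the test above would require collapsing a group that contains
// the dynamic dim, e.g. (hypothetical IR, not produced by the pattern):
//
//   memref.collapse_shape %mem [[0], [1, 2, 3]]
//       : memref<1x?x4x6xi32> into memref<1x?xi32>
//
// Presumably the pattern bails out because the size of the collapsed group -
// and with it the contiguity and in-bounds reasoning - is no longer static.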

// -----

// The vector to be written represents a _non-contiguous_ slice of the output
// memref.

func.func @transfer_write_dims_mismatch_non_contiguous_slice(
    %mem : memref<5x4x3x2xi8>,
    %vec : vector<2x1x2x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
  return
}

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_write_0d(
    %mem : memref<i8>,
    %vec : vector<i8>) {

  vector.transfer_write %vec, %mem[] : vector<i8>, memref<i8>
  return
}

// CHECK-LABEL: func.func @transfer_write_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_0d(
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

// The strides make the destination memref non-contiguous, hence non-flattenable.

func.func @transfer_write_non_contiguous_src(
    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

func.func @negative_out_of_bound_transfer_read(
    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
    memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK: func.func @negative_out_of_bound_transfer_read
// CHECK-NOT: memref.collapse_shape

// -----

func.func @negative_out_of_bound_transfer_write(
    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
    vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT: memref.collapse_shape
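
// A plausible reading of the two negative tests above: once the dims are
// collapsed, the single in_bounds flag on the resulting 1-D transfer could no
// longer express that only dim 0 of the original access may be out-of-bounds,
// so the pattern requires every dim it flattens to be in-bounds.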