// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

//-----------------------------------------------------------------------------
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
//-----------------------------------------------------------------------------
// Each function below is processed by the pattern set applied through
// `transform.apply_patterns.vector.rank_reducing_subview_patterns` (see the
// transform sequence at the end of this file). When a rewrite happens, unit
// dims are dropped from the source/destination memref through a rank-reducing
// memref.subview and, where possible, from the vector type as well (with a
// vector.shape_cast when the vector rank changes).

func.func @transfer_read_rank_reducing(
    %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
    memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
  return %v : vector<3x2xi8>
}
// CHECK-LABEL: func @transfer_read_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK:         vector.transfer_read %[[SUBVIEW]]

func.func @transfer_read_rank_reducing_masked(
    %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
    %mask: vector<3x2xi1>) -> vector<3x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %v = vector.mask %mask {
    vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
      memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
  } : vector<3x2xi1> -> vector<3x2xi8>
  return %v : vector<3x2xi8>
}
// CHECK-LABEL: func @transfer_read_rank_reducing_masked
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME:    %[[MASK:.+]]: vector<3x2xi1>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK:         vector.mask %[[MASK]]
// CHECK-SAME:    vector.transfer_read %[[SUBVIEW]]

func.func @transfer_write_rank_reducing(
    %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
    %vec : vector<3x2xi8>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
    vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
  return
}
// CHECK-LABEL: func @transfer_write_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK:         vector.transfer_write %{{.*}}, %[[SUBVIEW]]

func.func @transfer_write_rank_reducing_masked(
    %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
    %vec : vector<3x2xi8>,
    %mask: vector<3x2xi1>) {
  %c0 = arith.constant 0 : index
  vector.mask %mask {
    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
      vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
  } : vector<3x2xi1>
  return
}
// CHECK-LABEL: func @transfer_write_rank_reducing_masked
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME:    %[[VEC:.+]]: vector<3x2xi8>
// CHECK-SAME:    %[[MASK:.+]]: vector<3x2xi1>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK:         vector.mask %[[MASK]]
// CHECK-SAME:    vector.transfer_write %{{.*}}, %[[SUBVIEW]]

// Unit dims are dropped from both the memref and the vector type.
func.func @transfer_read_and_vector_rank_reducing(
    %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
    memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
  return %v : vector<3x2x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK:         vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>

func.func @transfer_write_and_vector_rank_reducing(
    %arg : memref<1x1x3x2x1xf32>,
    %vec : vector<3x2x1xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
    vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
  return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK:         vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>

// All dims are unit dims: the transfer is reduced to 0-d and the result is
// reshaped with vector.shape_cast.
func.func @transfer_read_and_vector_rank_reducing_to_0d(
    %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
    memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
  return %v : vector<1x1x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
// CHECK-SAME:    %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
// CHECK:         %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
// CHECK:         vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>

// A masked transfer that would reduce to 0-d is left unchanged: no subview and
// no shape_cast, and the original vector.mask still wraps a read of %arg.
func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
    %arg : memref<1x1x1x1x1xf32>,
    %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %v = vector.mask %mask {
    vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
      : memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
  } : vector<1x1x1xi1> -> vector<1x1x1xf32>
  return %v : vector<1x1x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x1x1x1xf32>
// CHECK-SAME:    %[[MASK:.+]]: vector<1x1x1xi1>
// CHECK-NOT:     vector.shape_cast
// CHECK-NOT:     memref.subview
// CHECK:         vector.mask %[[MASK]]
// CHECK-SAME:    vector.transfer_read %[[ARG]]

func.func @transfer_write_and_vector_rank_reducing_to_0d(
    %arg : memref<1x1x1x1x1xf32>,
    %vec : vector<1x1x1xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
    vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
  return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
// CHECK-SAME:    %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
// CHECK:         %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
// CHECK:         vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>

// Same as above for writes: the masked 0-d case is left unchanged.
func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
    %arg : memref<1x1x1x1x1xf32>,
    %vec : vector<1x1x1xf32>,
    %mask: vector<1x1x1xi1>) {
  %c0 = arith.constant 0 : index
  vector.mask %mask {
    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
      vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
  } : vector<1x1x1xi1>
  return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x1x1x1xf32>
// CHECK-SAME:    %[[VEC:.+]]: vector<1x1x1xf32>
// CHECK-SAME:    %[[MASK:.+]]: vector<1x1x1xi1>
// CHECK-NOT:     vector.shape_cast
// CHECK-NOT:     memref.subview
// CHECK:         vector.mask %[[MASK]]
// CHECK-SAME:    vector.transfer_write %[[VEC]], %[[ARG]]

// Dynamic (and scalable) non-unit dims are kept; their sizes are taken from
// memref.dim when building the subview.
func.func @transfer_read_dynamic_rank_reducing(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %arg[%c0, %c0], %pad {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
  return %v : vector<[16]x1xi8>
}
// CHECK-LABEL: func @transfer_read_dynamic_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK:         vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>

// A vector.create_mask mask is rank-reduced along with the transfer: operands
// for the dropped unit dims are removed.
func.func @masked_transfer_read_dynamic_rank_reducing_1(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
    %mask_dim0 : index) -> vector<[16]x1xi8> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0 : i8
  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
  return %v : vector<[16]x1xi8>
}
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME:    %[[MASK_DIM0:.+]]: index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[PAD:.+]] = arith.constant 0 : i8
// CHECK:         %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK:         vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>

func.func @masked_transfer_read_dynamic_rank_reducing_2(
    %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
    %mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %pad = arith.constant 0 : i8
  %mask = vector.create_mask %c1, %mask_dim1, %c2, %c1, %mask_dim4, %c1 : vector<1x[1]x3x1x[16]x1xi1>
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
    memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
  return %v : vector<1x[1]x3x1x[16]x1xi8>
}
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
// CHECK-SAME:    %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
// CHECK-SAME:    %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG:     %[[PAD:.+]] = arith.constant 0 : i8
// CHECK:         %[[MASK:.+]] = vector.create_mask %[[MASK_DIM1]], %[[C2]], %[[MASK_DIM4]] : vector<[1]x3x[16]xi1>
// CHECK:         %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
// CHECK:         %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
// CHECK:         vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>

func.func @masked_transfer_write_and_vector_rank_reducing(
    %arg : memref<1x1x3x1x16x1xf32>,
    %vec : vector<1x3x1x16x1xf32>,
    %mask_dim1 : index,
    %mask_dim2 : index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
    vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
  return
}
// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
// CHECK-SAME:    {{.*}}: vector<1x3x1x16x1xf32>,
// CHECK-SAME:    %[[MASKDIM1:.+]]: index,
// CHECK-SAME:    %[[MASKDIM2:.+]]: index
// CHECK:         %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
// CHECK-SAME:    memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
// CHECK:         vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>

func.func @masked_transfer_write_dynamic_rank_reducing(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
    %vec : vector<[16]x1xi8>,
    %mask_dim0 : index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
  vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
    vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
  return
}
// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME:    %{{.*}}: vector<[16]x1xi8>,
// CHECK-SAME:    %[[MASK_DIM0:.+]]: index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK:         vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>

/// Only masks produced by vector.create_mask are currently supported. Here the
/// mask is a function argument, so the transfer_read must be left unchanged.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
    %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
  return %v : vector<[16]x1xi8>
}
// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_1
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK-NOT:     vector.create_mask
// CHECK-NOT:     memref.subview
// CHECK:         vector.transfer_read %[[ARG]]

/// The mask size of a unit dim must be the constant 1. Here it is a function
/// argument (%mask_dim1), so the transfer_read must be left unchanged.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_2(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
    %mask_dim0 : index, %mask_dim1 : index) -> vector<[16]x1xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[16]x1xi1>
  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
  return %v : vector<[16]x1xi8>
}
// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_2
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK-NOT:     memref.subview
// CHECK:         vector.transfer_read %[[ARG]]{{.*}} vector<[16]x1xi8>

/// The unit dim must be non-scalable: a scalable `[1]` dim is not dropped.
func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
    %mask_dim0 : index) -> vector<[16]x[1]xi8> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0 : i8
  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x[1]xi1>
  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x[1]xi8>
  return %v : vector<[16]x[1]xi8>
}
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim
// CHECK-SAME:    %[[ARG:.+]]: memref<?x1xi8
// CHECK-NOT:     memref.subview
// CHECK:         vector.transfer_read %[[ARG]]{{.*}} vector<[16]x[1]xi8>

// Transform script driven by --transform-interpreter: matches every func.func
// in the payload and applies the `rank_reducing_subview_patterns` set to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.rank_reducing_subview_patterns
    } : !transform.op<"func.func">
    transform.yield
  }
}