// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s


// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>

// CHECK-LABEL: split_vector_transfer_read_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index

func.func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32

  // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // alloca for boundary full tile
  // CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  // CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
  // CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[d0]] : index
  // %j + 8 <= dim(%A, 1)
  // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  // CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
  // are both conds true
  // CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
  // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
  // inBounds, just yield %A
  // CHECK: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
  // CHECK: } else {
  // slow path, fill tmp alloc and yield a memref_casted version of it
  // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
  // CHECK-SAME: memref<?x8xf32>, vector<4x8xf32>
  // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<vector<4x8xf32>>
  // CHECK: store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
  // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<?x8xf32>
  // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // CHECK-SAME: memref<?x8xf32>, index, index
  // CHECK: }
  // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
  // CHECK-SAME: {in_bounds = [true, true]} : memref<?x8xf32>, vector<4x8xf32>

  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>

  return %1 : vector<4x8xf32>
}
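
// Variant with a statically shaped but strided source: the in-bounds branch
// yields a memref.cast of %A so that both scf.if branches produce the same
// memref<?x8xf32, strided<[?, 1], offset: ?>> type.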
// CHECK-LABEL: split_vector_transfer_read_strided_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index

func.func @split_vector_transfer_read_strided_2d(
    %A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
    %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32

  // CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
  // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // alloca for boundary full tile
  // CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  // CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[c7]] : index
  // %j + 8 <= dim(%A, 1)
  // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  // CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
  // are both conds true
  // CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
  // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
  // inBounds but not cast-compatible: yield a memref_casted form of %A
  // CHECK: %[[casted:.*]] = memref.cast %arg0 :
  // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
  // CHECK: scf.yield %[[casted]], %[[i]], %[[j]] :
  // CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
  // CHECK: } else {
  // slow path, fill tmp alloc and yield a memref_casted version of it
  // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
  // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>
  // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<vector<4x8xf32>>
  // CHECK: store %[[slow]], %[[cast_alloc]][] :
  // CHECK-SAME: memref<vector<4x8xf32>>
  // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
  // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
  // CHECK: }
  // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} :
  // CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 :
    memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>

  // CHECK: return %[[res]] : vector<4x8xf32>
  return %1 : vector<4x8xf32>
}
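
// Variant with a non-default memory space: the in-bounds branch applies
// memref.memory_space_cast followed by memref.cast so that both branches
// yield a memref<?x8xf32, strided<[8, 1]>>.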
func.func @split_vector_transfer_read_mem_space(%A: memref<?x8xf32, 3>, %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32

  // CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
  // inBounds with a different memory space
  // CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
  // CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
  // CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
  // CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
  // CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
  // CHECK: } else {
  // slow path, fill tmp alloc and yield a memref_casted version of it
  // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
  // CHECK-SAME: memref<?x8xf32, 3>, vector<4x8xf32>
  // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<vector<4x8xf32>>
  // CHECK: store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
  // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
  // CHECK-SAME: memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
  // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // CHECK-SAME: memref<?x8xf32, strided<[8, 1]>>, index, index
  // CHECK: }
  // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
  // CHECK-SAME: {in_bounds = [true, true]} : memref<?x8xf32, strided<[8, 1]>>, vector<4x8xf32>

  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32, 3>, vector<4x8xf32>

  return %1 : vector<4x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
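
// The write side of the split: the vector.transfer_write targets the
// scf.if-selected destination, and a trailing scf.if writes the temporary
// buffer back to the original destination when the access was out of bounds.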
func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %i: index, %j: index) {
  vector.transfer_write %V, %A[%i, %j] :
    vector<4x8xf32>, memref<?x8xf32>
  return
}

// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK: func @split_vector_transfer_write_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x8xf32>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index) {
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CT:.*]] = arith.constant true
// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK: %[[VAL_8:.*]] = affine.apply #[[MAP0]]()[%[[I]]]
// CHECK: %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[VAL_8]], %[[DIM0]] : index
// CHECK: %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[J]]]
// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] ->
// CHECK-SAME: (memref<?x8xf32>, index, index) {
// CHECK: scf.yield %[[DEST]], %[[I]], %[[J]] : memref<?x8xf32>, index, index
// CHECK: } else {
// CHECK: %[[VAL_15:.*]] = memref.cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32>
// CHECK: scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
// CHECK-SAME: : memref<?x8xf32>, index, index
// CHECK: }
// CHECK: vector.transfer_write %[[VEC]],
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32>
// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK: scf.if %[[OUT_BOUNDS]] {
// CHECK: %[[CASTED:.*]] = vector.type_cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<vector<4x8xf32>>
// CHECK: %[[RESULT_COPY:.*]] = memref.load %[[CASTED]][]
// CHECK-SAME: : memref<vector<4x8xf32>>
// CHECK: vector.transfer_write %[[RESULT_COPY]],
// CHECK-SAME: %[[DEST]][%[[I]], %[[J]]]
// CHECK-SAME: : vector<4x8xf32>, memref<?x8xf32>
// CHECK: }
// CHECK: return
// CHECK: }


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

func.func @split_vector_transfer_write_strided_2d(
    %V: vector<4x8xf32>, %A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
    %i: index, %j: index) {
  vector.transfer_write %V, %A[%i, %j] :
    vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>>
  return
}

// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK: func @split_vector_transfer_write_strided_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index) {
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CT:.*]] = arith.constant true
// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK: %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[I]]]
// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[DIM0]], %[[C7]] : index
// CHECK: %[[DIM1:.*]] = affine.apply #[[MAP2]]()[%[[J]]]
// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
// CHECK-SAME: -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
// CHECK: %[[VAL_15:.*]] = memref.cast %[[DEST]]
// CHECK-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[VAL_15]], %[[I]], %[[J]]
// CHECK-SAME: : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: } else {
// CHECK: %[[VAL_16:.*]] = memref.cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[VAL_16]], %[[C0]], %[[C0]]
// CHECK-SAME: : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: }
// CHECK: vector.transfer_write %[[VEC]],
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0
// CHECK-SAME: [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK: scf.if %[[OUT_BOUNDS]] {
// CHECK: %[[VAL_19:.*]] = vector.type_cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<vector<4x8xf32>>
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_19]][]
// CHECK-SAME: : memref<vector<4x8xf32>>
// CHECK: vector.transfer_write %[[VAL_20]], %[[DEST]][%[[I]], %[[J]]]
// CHECK-SAME: : vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>>
// CHECK: }
// CHECK: return
// CHECK: }

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
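
// Write variant with a non-default memory space: as in the read case, the
// in-bounds branch goes through memref.memory_space_cast followed by
// memref.cast before yielding.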
func.func @split_vector_transfer_write_mem_space(%V: vector<4x8xf32>, %A: memref<?x8xf32, 3>, %i: index, %j: index) {
  vector.transfer_write %V, %A[%i, %j] :
    vector<4x8xf32>, memref<?x8xf32, 3>
  return
}

// CHECK: func @split_vector_transfer_write_mem_space(
// CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
// CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
// CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
// CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
// CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
// CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
// CHECK: } else {
// CHECK: %[[VAL_15:.*]] = memref.cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
// CHECK: scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
// CHECK-SAME: : memref<?x8xf32, strided<[8, 1]>>, index, index
// CHECK: }
// CHECK: vector.transfer_write %[[VEC]],
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[8, 1]>>


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
    } : !transform.op<"func.func">
    transform.yield
  }
}


// -----

func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()

// CHECK-LABEL: transfer_read_within_async_execute
func.func @transfer_read_within_async_execute(%A : memref<?x?xf32>) -> !async.token {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  // CHECK-NOT: alloca
  // CHECK: async.execute
  // CHECK: alloca
  %token = async.execute {
    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x?xf32>, vector<2x2xf32>
    func.call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
    async.yield
  }
  return %token : !async.token
}

// Ensure that `alloca`s are inserted outside of loops even though loops are
// considered allocation scopes.
// CHECK-LABEL: transfer_read_within_scf_for
func.func @transfer_read_within_scf_for(%A : memref<?x?xf32>, %lb : index, %ub : index, %step : index) {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  // CHECK: memref.alloca
  // CHECK: scf.for
  // CHECK-NOT: memref.alloca
  scf.for %i = %lb to %ub step %step {
    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x?xf32>, vector<2x2xf32>
    func.call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
    } : !transform.op<"func.func">
    transform.yield
  }
}