// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s

// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs(
    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
// CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
// CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
// CHECK: scf.yield {{.*}} : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> () 57 %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> 58 %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> 59 %u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32> 60 %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> 61 %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> 62 %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> 63 vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32> 64 vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32> 65 vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32> 66 vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32> 67 vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32> 68 vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32> 69 "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> () 70 } 71 "unrelated_use"(%memref0) : (memref<?x?xf32>) -> () 72 } 73 "unrelated_use"(%memref1) : (memref<?x?xf32>) -> () 74 return 75} 76 77module attributes {transform.with_named_sequence} { 78 transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { 79 %0 = transform.structured.match ops{["func.func"]} in %arg1 80 : (!transform.any_op) -> !transform.any_op 81 transform.structured.hoist_redundant_vector_transfers %0 82 : (!transform.any_op) -> !transform.any_op 83 transform.yield 84 } 85} 86 87// ----- 88 89// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint( 90// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>, 91// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>, 92// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>, 93// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>, 94// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, 95// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index, 96// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index, 97// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, 98// CHECK-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index, 99// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 100func.func @hoist_vector_transfer_pairs_disjoint( 101 %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, 102 %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index, 103 %step: index, %random_index : index, %cmp: i1) { 104 %c0 = arith.constant 0 : index 105 %c1 = arith.constant 1 : index 106 %c3 = arith.constant 3 : index 107 %cst = arith.constant 0.0 : f32 108 109// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32> 110// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32> 111// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32> 112// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32> 113// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> 114// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) { 115// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> 116// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) { 117// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32> 118// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32> 119// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> 120// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> 121// CHECK: "some_use"(%{{.*}}) 
// CHECK: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK: }
// CHECK: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<2xf32>
      %r01 = vector.transfer_read %memref1[%c0, %c1], %cst: memref<?x?xf32>, vector<2xf32>
      %r20 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
      %r21 = vector.transfer_read %memref2[%c0, %c3], %cst: memref<?x?xf32>, vector<3xf32>
      %r30 = vector.transfer_read %memref3[%c0, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
      %r31 = vector.transfer_read %memref3[%c1, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
      %r10 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: memref<?x?xf32>, vector<2xf32>
      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
      vector.transfer_write %u00, %memref1[%c0, %c0] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u01, %memref1[%c0, %c1] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u20, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
      vector.transfer_write %u21, %memref2[%c0, %c3] : vector<3xf32>, memref<?x?xf32>
      vector.transfer_write %u30, %memref3[%c0, %random_index] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u31, %memref3[%c1, %random_index] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref<?x?xf32>
    }
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: affine.for %[[I:.*]] = 0 to 64 {
// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 {
// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
// CHECK: affine.yield %[[T1]] : vector<16xi32>
// CHECK: }
// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
// CHECK: }
// CHECK: }
func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
  %c0_i32 = arith.constant 0 : i32
  affine.for %arg3 = 0 to 64 {
    affine.for %arg4 = 0 to 64 step 16 {
      affine.for %arg5 = 0 to 64 {
        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
        %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
        %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
        %3 = arith.muli %0, %1 : vector<16xi32>
        %4 = arith.addi %2, %3 : vector<16xi32>
        vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
      }
    }
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @hoist_vector_transfer_read(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32>
// CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x128xf32>
// CHECK: %[[CAST:.+]] = memref.cast %[[ALLOC_0]] : memref<32x128xf32> to memref<32x128xf32, strided<[128, 1],
// CHECK-SAME: offset: ?>>
// CHECK: %[[D0:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} :
// CHECK-SAME: memref<32x64xf32>, vector<32x64xf32>
// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C1024]] step %[[C128]] {
// CHECK: %[[D1:.+]] = vector.transfer_read %[[ALLOC_0]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]}
// CHECK-SAME: : memref<32x128xf32>, vector<32x128xf32>
// CHECK: "some_use"(%[[D0]], %[[D1]], %[[CAST]]) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32,
// CHECK-SAME: strided<[128, 1], offset: ?>>) -> ()
// CHECK: }
// CHECK: memref.dealloc %[[ALLOC]] : memref<32x64xf32>
// CHECK: return
func.func @hoist_vector_transfer_read() {
  %c0 = arith.constant 0 : index
  %c128 = arith.constant 128 : index
  %c1024 = arith.constant 1024 : index
  %cst_2 = arith.constant 0.000000e+00 : f32
  %memref0 = memref.alloc() : memref<32x64xf32>
  %memref2 = memref.alloc() : memref<32x128xf32>
  %subview2 = memref.subview %memref2[%c0, %c0] [32, 128] [1, 1]: memref<32x128xf32> to memref<32x128xf32, strided<[128, 1], offset: ?>>
  scf.for %arg0 = %c0 to %c1024 step %c128 {
    %2 = vector.transfer_read %memref2[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x128xf32>, vector<32x128xf32>
    %3 = vector.transfer_read %memref0[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32>
    "some_use"(%3, %2, %subview2) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32, strided<[128, 1], offset: ?>>) -> ()
  }
  memref.dealloc %memref0 : memref<32x64xf32>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// The transfers in this test case cannot be hoisted and replaced by a vector
// iter_arg because they do not match.
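// (The read uses a permutation map that drops the unit dimension and yields a
// vector<6x7x32xf32>, while the write stores a full-rank vector<6x1x7x32xf32>,
// so there is no matching read/write pair to forward through an iter_arg.)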

// CHECK-LABEL: func.func @non_matching_transfers(
// CHECK: scf.for {{.*}} {
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: }
func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) {
  %c0 = arith.constant 0 : index
  %c1024 = arith.constant 1024 : index
  %c128 = arith.constant 128 : index
  %cst = arith.constant dense<5.5> : vector<6x7x32xf32>
  %cst_0 = arith.constant 0.0 : f32
  scf.for %iv = %c0 to %c1024 step %c128 {
    %read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32>
    %added = arith.addf %read, %cst : vector<6x7x32xf32>
    %bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32>
    %tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32>
    vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32>
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @no_hoisting_unknown_bound_loop
func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // The bounds %lb and %ub are unknown, so the loop may execute zero times;
  // do not hoist.
  // CHECK: scf.for {{.*}} {
  // CHECK-NEXT: vector.transfer_read
  // CHECK-NEXT: "test.some_use"
  scf.for %arg2 = %lb to %ub step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop
func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 8) is at most 8, and %ub_0 = max(%ub, 4) is at least 4.
  // Since %lb_0 could be greater than %ub_0, do not hoist.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)

  // CHECK: scf.for {{.*}} {
  // CHECK-NEXT: vector.transfer_read
  // CHECK-NEXT: "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub
func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 8) is at most 8, and %ub_0 = max(%ub, 8) is at least 8.
  // Since %lb_0 could be equal to %ub_0, do not hoist.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

  // CHECK: scf.for {{.*}} {
  // CHECK-NEXT: vector.transfer_read
  // CHECK-NEXT: "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @hoisting_non_zero_trip_loop
func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 4) is at most 4, and %ub_0 = max(%ub, 8) is at least 8.
  // Since %lb_0 is guaranteed to be less than %ub_0, hoisting is possible.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

  // CHECK: vector.transfer_read
  // CHECK: scf.for {{.*}} {
  // CHECK-NEXT: "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
// (read by `vector.transfer_read`) alias.

// CHECK-LABEL: func.func @no_hoisting_collapse_shape
// CHECK: scf.for {{.*}} {
// CHECK: vector.transfer_write {{.*}} : vector<4xi32>, memref<4xi32>
// CHECK-NEXT: vector.transfer_read {{.*}} : memref<1x4x1xi32>, vector<1x4x1xi32>
// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
// CHECK-NEXT: }

func.func @no_hoisting_collapse_shape(%in_0: memref<1x20x1xi32>, %1: memref<9x1xi32>, %vec: vector<4xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x4x1xi32>
  scf.for %arg0 = %c0 to %c20 step %c4 {
    %subview = memref.subview %in_0[0, %arg0, 0] [1, 4, 1] [1, 1, 1] : memref<1x20x1xi32> to memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
    %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x4x1xi32> into memref<4xi32>
    vector.transfer_write %vec, %collapse_shape[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
    %read = vector.transfer_read %alloca[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
    vector.transfer_write %read, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (read by `vector.transfer_read`) and %alloca
// (written to by `vector.transfer_write`) alias.

// CHECK-LABEL: func.func @no_hoisting_collapse_shape_2
// CHECK: scf.for {{.*}} {
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read

func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x12x1xi32>
  scf.for %arg0 = %c0 to %c20 step %c4 {
    %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
    vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
    %read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
    "test.some_use"(%read) : (vector<12xi32>) ->()
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - hoisting the following `vector.transfer_{read|write}` pair
// would not be safe:
//   %lhs = vector.transfer_read %collapsed_1[%c0]
//   vector.transfer_write %op, %collapsed_1[%c0]
// That's because the following `vector.transfer_read` reads from the same
// memory (i.e. `%collapsed_1` and `%collapsed_2` alias):
//   %acc = vector.transfer_read %collapsed_2[%c0]

// CHECK-LABEL: func.func @no_hoisting_write_to_memref
// CHECK: scf.for {{.*}} {
// CHECK: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
// CHECK-NEXT: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
// CHECK-NEXT: vector.outerproduct {{.*}} : vector<1xi32>, i32
// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1xi32>, memref<2xi32>
// CHECK-NEXT: }

func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
  %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32>
  %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
  scf.for %_ = %c0 to %c20 step %c4 {
    %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
    %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
    %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
    %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32
    vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values.

// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>

// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)

// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
// CHECK: %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
// CHECK: %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
// CHECK: %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
// CHECK-COUNT-2: scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
// CHECK-COUNT-3: "some_use"
// CHECK-COUNT-2: scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]

func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
  %cst = arith.constant 0.0 : f32
  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
  %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)

  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
      // Disjoint leading dim
      %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
      // Non-overlapping trailing dim
      %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
      %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
    }
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test that we cannot hoist out read-write pairs whose indices are overlapping.
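// Below, %i1 = %i0 + 3, so the 4-element accesses at [%i0, %i0] and
// [%i0, %i1] overlap in the trailing dimension and cannot be proven disjoint.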

// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
// CHECK-COUNT-2: scf.for
// CHECK-COUNT-2: vector.transfer_read
// CHECK-COUNT-2: vector.transfer_write

func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
  %cst = arith.constant 0.0 : f32
  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)

  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
      // Overlapping range with the above
      %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref<?x?xf32>, vector<4xf32>
      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref<?x?xf32>
    }
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.

// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
// CHECK-COUNT-3: vector.transfer_read
// CHECK-COUNT-2: %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>)
// CHECK-COUNT-2: scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>
// CHECK-COUNT-3: vector.transfer_write
// CHECK: return

func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
  %cst = arith.constant 0.0 : f32
  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
  %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1)

  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
      %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
      %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref<?x?xf32>, vector<16x8xf32>
      %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
      %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
      %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32>
      vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
    }
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs

// CHECK-LABEL: func.func @hoist_vector_broadcasts
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position

// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args

// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
// CHECK-NEXT: }
// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>

func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>, vector<3x5xf32>
  }
  return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}