// RUN: mlir-opt %s -split-input-file -loop-invariant-subset-hoisting | FileCheck %s

// A matching extract_slice/insert_slice pair on the loop-carried tensor is
// hoisted out of the loop. The unrelated extract_slice at offset 9 stays in
// the loop body.

// CHECK-LABEL: func @hoist_matching_extract_insert(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %add = arith.addi %c0, %c1 : index
  %sub = arith.subi %add, %c1 : index

  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
    %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
    "test.foo"(%standalone) : (tensor<5xf32>) -> ()

    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
    // have the same value.
    %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
    // CHECK: scf.yield %[[t]], %[[foo]]
    scf.yield %3 : tensor<?xf32>
  }
  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0

  // CHECK: return %[[insert]]
  return %0 : tensor<?xf32>
}

// -----

// A subset of a subset: %extract2 is taken from %extract1. Both levels are
// hoisted, and after the loop the inner insert_slice is materialized first,
// followed by the outer one.

// CHECK-LABEL: func @subset_of_subset(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
    %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>

    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
    %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)

    %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
    %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>

    // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
    scf.yield %insert2 : tensor<?xf32>
  }
  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]

  // CHECK: return %[[insert1]]
  return %0 : tensor<?xf32>
}

// -----

// Two extract/insert pairs chained on the same loop-carried tensor, one at
// offset 0 with size %sz and one at offset %sz with size 5. Both pairs are
// hoisted.

// CHECK-LABEL: func @hoist_matching_chain(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)
  %sz = "test.foo"() : () -> (index)

  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
    %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
    // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
    %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
    %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
    // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
    scf.yield %6 : tensor<?xf32>
  }
  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]

  // CHECK: return %[[insert1]]
  return %0 : tensor<?xf32>
}

// -----

// Slices with dynamic sizes that may overlap must stay inside the loop.

// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)
  %sz1 = "test.foo"() : () -> (index)
  %sz2 = "test.foo"() : () -> (index)

  // CHECK: scf.for
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    // These two slices are potentially overlapping. Do not hoist.
    // CHECK: tensor.extract_slice
    // CHECK: tensor.extract_slice
    %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
    %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
    // CHECK: "test.foo"
    // CHECK: "test.foo"
    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
    %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
    // CHECK: tensor.insert_slice
    // CHECK: tensor.insert_slice
    %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
    %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
    // CHECK: scf.yield
    scf.yield %6 : tensor<?xf32>
  }

  return %0 : tensor<?xf32>
}

// -----

// Two loop-carried tensors, each with its own extract/insert pair that is
// yielded back in its original position. Both pairs are hoisted.

// CHECK-LABEL: func @multiple_yields(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: %[[extract1:.*]] = tensor.extract_slice
  // CHECK: %[[extract2:.*]] = tensor.extract_slice
  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: "test.foo"
    // CHECK: "test.foo"
    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
    // CHECK: scf.yield
    scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
  }
  // CHECK: tensor.insert_slice
  // CHECK: tensor.insert_slice

  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
}

// -----

// Same IR as @multiple_yields, except that the two results are yielded in
// swapped positions. Nothing is hoisted.

// CHECK-LABEL: func @do_not_hoist_swapping_yields(
func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: scf.for
  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
      -> (tensor<?xf32>, tensor<?xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK: tensor.extract_slice
    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: "test.foo"
    // CHECK: "test.foo"
    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
    // CHECK: tensor.insert_slice
    // CHECK: tensor.insert_slice
    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
    // Swapping yields: do not hoist.
    // CHECK: scf.yield
    scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
  }

  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
}

// -----

// A non-subset use of the loop-carried tensor blocks hoisting.

// CHECK-LABEL: func @non_subset_op(
func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: scf.for
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    // If any value along the use-def chain from the region iter_arg to the
    // terminator is used by a non-subset op, no subset op along that chain can
    // be hoisted. That is because it is unknown which parts of the value are
    // accessed by the non-subset op.
    // CHECK: "test.non_subset_op"
    "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
    // CHECK: tensor.extract_slice
    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: "test.foo"
    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    // CHECK: tensor.insert_slice
    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
    // CHECK: scf.yield
    scf.yield %3 : tensor<?xf32>
  }

  return %0 : tensor<?xf32>
}

// -----

// Slices whose offset is the induction variable stay inside the loop.

// CHECK-LABEL: func @non_loop_invariant_subset_op(
func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: scf.for
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    // Subset ops that are not loop-invariant cannot be hoisted.
    // CHECK: tensor.extract_slice
    %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: "test.foo"
    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    // CHECK: tensor.insert_slice
    %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
    // CHECK: scf.yield
    scf.yield %3 : tensor<?xf32>
  }

  return %0 : tensor<?xf32>
}

// -----

// Two-level loop nest: the outer loop's pair is hoisted out of the outer loop
// and the inner loop's pair is hoisted out of both loops.

// CHECK-LABEL: func @nested_hoisting(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
  %lb = "test.foo"() : () -> (index)
  %ub = "test.foo"() : () -> (index)
  %step = "test.foo"() : () -> (index)

  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
    // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
    %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
      %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
      // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
      %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
      %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
      // CHECK: scf.yield %[[t2]], %[[foo2]]
      scf.yield %7 : tensor<?xf32>
    }
    // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
    scf.yield %4 : tensor<?xf32>
  }
  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
  // CHECK: return %[[insert2]]
  return %0 : tensor<?xf32>
}

// -----

// vector.transfer_read/transfer_write pairs on loop-carried tensors. Per the
// expected output below, the vector<1xf32> pair is hoisted out of both loops
// and the vector<2xf32> pair (indexed by %i) out of the inner loop only. The
// remaining expected transfers stay in the inner loop.

// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor(
func.func @hoist_vector_transfer_pairs_tensor(
    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index) ->
    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
     tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
// CHECK: "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "test.some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
// CHECK: "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
// CHECK: scf.yield {{.*}} :
// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
// CHECK: scf.yield {{.*}} :
// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
  %0:6 = scf.for %i = %lb to %ub step %step
      iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
                %arg3 = %tensor3, %arg4 = %tensor4, %arg5 = %tensor5)
      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
          tensor<?x?xf32>, tensor<?x?xf32>) {
    %1:6 = scf.for %j = %lb to %ub step %step
        iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
                  %arg9 = %arg3, %arg10 = %arg4, %arg11 = %arg5)
        -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
            tensor<?x?xf32>, tensor<?x?xf32>) {
      %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
      %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
      %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
      "test.some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
      %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
      %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
      "test.some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
      %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
      %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
      %u2 = "test.some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
      %u3 = "test.some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
      %u4 = "test.some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
      %u5 = "test.some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
      %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
      %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
      %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
      %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
      %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
      %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
      "test.some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
      scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
        tensor<?x?xf32>, tensor<?x?xf32>
    }
    scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5 :
      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
      tensor<?x?xf32>, tensor<?x?xf32>
  }
  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 :
    tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
    tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// Pairs of transfers on the same tensor at different indices. Per the expected
// output below, the vector<3xf32> and vector<4xf32> pairs are hoisted out of
// both loops, while vector<2xf32> transfers remain in the inner loop.

// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
func.func @hoist_vector_transfer_pairs_disjoint_tensor(
    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
    %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index,
    %random_index : index) ->
    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %cst = arith.constant 0.0 : f32

// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
// CHECK: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
// CHECK: scf.yield {{.*}} :
// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK: }
// CHECK: scf.yield {{.*}} :
// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK: }
// CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#7, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %[[R]]#6, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
// CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#5, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %[[R]]#4, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
  %0:4 = scf.for %i = %lb to %ub step %step
      iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
                %arg3 = %tensor3)
      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
    %1:4 = scf.for %j = %lb to %ub step %step
        iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
                  %arg7 = %arg3)
        -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
      %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
      %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
      %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
      %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
      %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
      %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
      %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
      %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
      %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
      %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
      %u20 = "test.some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
      %u21 = "test.some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
      %u30 = "test.some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
      %u31 = "test.some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
      %u10 = "test.some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
      %u11 = "test.some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
      %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
      %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
      %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
      %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
      %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
      %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
      %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
      %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
      scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
    }
    scf.yield %1#0, %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
  }
  return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// Mix of tensor.extract_slice/insert_slice and vector transfers on the slices.
// Per the expected output below, only the %i-indexed slice of the first
// loop-carried tensor and the transfers on it are hoisted, and only out of the
// inner loop.

// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices(
// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
func.func @hoist_vector_transfer_pairs_tensor_and_slices(
    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index) ->
    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

  // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
  // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
  // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
  // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
  // CHECK-SAME: ) ->
  // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
  %0:3 = scf.for %i = %lb to %ub step %step
      iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {

    // Hoisted
    // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
    // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>

    // CHECK: %[[R:.*]]:5 = scf.for %[[J:.*]] = {{.*}} iter_args(
    // CHECK-SAME: %[[TENSOR0_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR0_ARG]]
    // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
    // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
    // CHECK-SAME: %[[ST0_ARG_L2:[0-9a-zA-Z]+]] = %[[ST0]]
    // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
    // CHECK-SAME: ) ->
    // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>)
    %1:3 = scf.for %j = %lb to %ub step %step
        iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
        -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
      // Hoists.
      %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>

      // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
      // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
      // Does not hoist (slice depends on %j)
      %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>

      // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
      // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
      // Does not hoist, 2 slice %arg8.
      %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>

      // CHECK: %[[U0:.*]] = "test.some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
      // CHECK: %[[U1:.*]] = "test.some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
      // CHECK: %[[U2:.*]] = "test.some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
      %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
      %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
      %u2 = "test.some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>

      // Hoists
      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>

      // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
      // Does not hoist (associated slice depends on %j).
      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>

      // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
      // Does not hoist, 2 slice / insert_slice for %arg8.
      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>

      // Hoists.
      %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

      // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
      // Does not hoist (depends on %j).
      %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

      // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
      // Does not hoist, 2 slice / insert_slice for %arg8.
      %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

      // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
      // CHECK: }
      scf.yield %sti0, %sti1, %sti22:
        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
    }

    // Hoisted
    // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#4, %[[R]]#3{{.*}} : vector<1xf32>, tensor<?x?xf32>
    // CHECK: tensor.insert_slice %[[STI0]] into %[[R]]#0[%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>

    // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
    scf.yield %1#0, %1#1, %1#2 :
      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>

    // CHECK: }
  }
  return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// Two read/write pairs on the same loop-carried tensor, where the second read
// goes through the result of the first write. Both pairs are hoisted.

// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
// CHECK-SAME: %[[T:.*]]: tensor<?x?xf32>,
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
// CHECK: %[[F:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[TL:.*]] = %[[T]], %[[R2:.*]] = %[[R0]], %[[R3:.*]] = %[[R1]]) -> (tensor<?x?xf32>, vector<2xf32>, vector<2xf32>) {
// CHECK: %[[R4:.*]] = "test.some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
// CHECK: %[[R5:.*]] = "test.some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
// CHECK: scf.yield %[[TL]], %[[R4]], %[[R5]] : tensor<?x?xf32>, vector<2xf32>, vector<2xf32>
// CHECK: }
// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#2, %[[F]]#0[%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#1, %[[W0]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
// CHECK: return %[[W1]] : tensor<?x?xf32>
func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
    %tensor: tensor<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index) ->
    (tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %cst = arith.constant 0.0 : f32
  %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
      -> (tensor<?x?xf32>) {
    %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
    %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
    %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>

    // Hoist by properly bypassing the disjoint write %w10.
    %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
    %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
    %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
    scf.yield %w11 : tensor<?x?xf32>
  }
  return %1 : tensor<?x?xf32>
}