// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(scf-parallel-loop-fusion))' -split-input-file | FileCheck %s

// Regression tests for the scf-parallel-loop-fusion pass.
// Every case is an independent input chunk (the file is processed with
// -split-input-file). Each chunk holds one function followed by the output
// expectations for that function. Cases named "fuse_*" / "nested_fuse" expect
// adjacent scf.parallel loops to be merged into one; cases named
// "do_not_fuse_*" / "reductions_use_res*" expect the loops to stay separate.

// Two adjacent 2-D loops with identical bounds and empty bodies: the
// expectations allow exactly one scf.parallel in the output.
func.func @fuse_empty_loops() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @fuse_empty_loops
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           scf.reduce
// CHECK:         }
// CHECK-NOT:     scf.parallel

// -----

// An arith.addf that does not touch the loops sits between them: the
// expectations are the addf followed by a single scf.parallel.
func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  %res = arith.addf %A, %B : f32
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return %res : f32
}
// CHECK-LABEL: func @fuse_ops_between
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK:         %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           scf.reduce
// CHECK:         }
// CHECK-NOT:     scf.parallel

// -----

// The first loop stores to %sum[%i, %j] and the second loop loads %sum at the
// same indices: the expectations are one loop containing both bodies in order.
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %c1fp : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME:    ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C1FP:%.*]] = arith.constant 1.
// CHECK:         [[SUM:%.*]] = memref.alloc()
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK:           memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:     scf.parallel
// CHECK:           [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:           [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:           [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:           memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           scf.reduce
// CHECK:         }
// CHECK:         memref.dealloc [[SUM]]

// -----

// Three consecutive loops over the same iteration space: the expectations are
// a single loop holding the three bodies in their original order.
func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %c2fp = arith.constant 2.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  %prod = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %c1fp : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %c2fp : f32
    memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %res_elem = arith.addf %A_elem, %c2fp : f32
    memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  memref.dealloc %prod : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME:    ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C1FP:%.*]] = arith.constant 1.
// CHECK-DAG:     [[C2FP:%.*]] = arith.constant 2.
// CHECK:         [[SUM:%.*]] = memref.alloc()
// CHECK:         [[PROD:%.*]] = memref.alloc()
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK:           memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:     scf.parallel
// CHECK:           [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:           [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
// CHECK:           memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:     scf.parallel
// CHECK:           [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:           [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
// CHECK:           memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           scf.reduce
// CHECK:         }
// CHECK:         memref.dealloc [[SUM]]
// CHECK:         memref.dealloc [[PROD]]

// -----

// The first loop contains a nested scf.parallel: all three loops are expected
// to remain in the output.
func.func @do_not_fuse_nested_ploop1() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      scf.reduce
    }
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop1
// CHECK:         scf.parallel
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// Same as the previous case, but the nested scf.parallel lives in the second
// loop: all three loops are expected to remain in the output.
func.func @do_not_fuse_nested_ploop2() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      scf.reduce
    }
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop2
// CHECK:         scf.parallel
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// A 2-D loop followed by a 1-D loop: both loops are expected to remain.
func.func @do_not_fuse_loops_unmatching_num_loops() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// A memref.alloc sits between the two loops: both loops are expected to
// remain.
func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  %buffer = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The loops use different upper bounds and steps (0 to 4 step 2 versus
// 0 to 2 step 1): both loops are expected to remain.
func.func @do_not_fuse_loops_unmatching_iteration_space() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The first loop writes %common_buf[%i, %j] while the second loop reads
// %common_buf[%i + 1, %j]: both loops are expected to remain.
func.func @do_not_fuse_unmatching_write_read_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %common_buf = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %C_elem : f32
    memref.store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %k = arith.addi %i, %c1 : index
    %sum_elem = memref.load %common_buf[%k, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %common_buf : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The first loop reads %common_buf[%i, %j] while the second loop writes
// %common_buf[%j, %i] (and reads %sum at a shifted index): both loops are
// expected to remain.
func.func @do_not_fuse_unmatching_read_write_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sum = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %C_elem = memref.load %common_buf[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %C_elem : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %k = arith.addi %i, %c1 : index
    %sum_elem = memref.load %sum[%k, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The second loop body creates a memref.subview of %buffer and loads from it:
// both loops are expected to remain.
func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %buffer = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
      : memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
    %A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// Two fusable loops nested inside an outer scf.parallel over %k: the
// expectations are the outer loop containing a single inner loop.
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
      %sum_elem = arith.addf %B_elem, %c1fp : f32
      memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
      scf.reduce
    }
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
      %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
      %product_elem = arith.mulf %sum_elem, %A_elem : f32
      memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
      scf.reduce
    }
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME:    ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C1FP:%.*]] = arith.constant 1.
// CHECK:         [[SUM:%.*]] = memref.alloc()
// CHECK:         scf.parallel
// CHECK:           scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:          to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:             [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:             [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK:             memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:       scf.parallel
// CHECK:             [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:             [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:             [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:             memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:             scf.reduce
// CHECK:           }
// CHECK:         }
// CHECK:         memref.dealloc [[SUM]]

// -----

// Every buffer, including the one written by the first loop and read by the
// second (%sum), is a function argument.
func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
                             %C: memref<2x2xf32>, %result: memref<2x2xf32>,
                             %sum: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %C_elem : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  return
}
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The first loop stores twice to %sum, both times at [%i, %j], and the second
// loop loads %sum at the same indices: a single loop is expected.
func.func @fuse_when_1st_has_multiple_stores(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c0fp = arith.constant 0.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %B_elem : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
// CHECK-SAME:    ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0F32:%.*]] = arith.constant 0.
// CHECK:         [[SUM:%.*]] = memref.alloc()
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK:           memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:     scf.parallel
// CHECK:           [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:           [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:           [[PRODUCT_ELEM:%.*]] = arith.mulf
// CHECK:           memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           scf.reduce
// CHECK:         }
// CHECK:         memref.dealloc [[SUM]]

// -----

// Like the previous case, but the second store in the first loop targets
// %sum[%c0, %j] instead of %sum[%i, %j]: both loops are expected to remain.
func.func @do_not_fuse_multiple_stores_on_diff_indices(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c0fp = arith.constant 0.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %B_elem : f32
    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
// CHECK-SAME:    ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:     [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:     [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:     [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:     [[C0F32:%.*]] = arith.constant 0.
// CHECK:         [[SUM:%.*]] = memref.alloc()
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:        to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:           [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK:           memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
// CHECK:           scf.reduce
// CHECK:         scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK:           [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:           [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:           [[PRODUCT_ELEM:%.*]] = arith.mulf
// CHECK:           memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:           scf.reduce
// CHECK:         }
// CHECK:         memref.dealloc [[SUM]]

// -----

// Both loops index %sum through an affine.apply with the same map applied to
// the loop induction variables (%i, %j): a single loop is expected.
func.func @fuse_same_indices_by_affine_apply(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %sum = memref.alloc() : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x3xf32>
  return
}
// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fuse_same_indices_by_affine_apply
// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
// CHECK:       %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-NEXT:  scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
// CHECK-NEXT:    %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
// CHECK-NEXT:    memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
// CHECK-NEXT:    %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
// CHECK-NEXT:    %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
// CHECK-NEXT:    %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
// CHECK-NEXT:    memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
// CHECK-NEXT:    scf.reduce
// CHECK-NEXT:  }
// CHECK-NEXT:  memref.dealloc %[[ALLOC]] : memref<2x3xf32>
// CHECK-NEXT:  return

// -----

// The affine.apply operands include function arguments (%OffsetA in the first
// loop, %OffsetB in the second) rather than only induction variables: both
// loops are expected to remain, unchanged.
func.func @do_not_fuse_affine_apply_to_non_ind_var(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %OffsetA: index, %OffsetB: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %sum = memref.alloc() : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetA)
    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetB)
    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x3xf32>
  return
}
// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
// CHECK:       %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-NEXT:  scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
// CHECK-NEXT:    %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
// CHECK-NEXT:    memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
// CHECK-NEXT:    scf.reduce
// CHECK-NEXT:  }
// CHECK-NEXT:  scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
// CHECK-NEXT:    %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
// CHECK-NEXT:    %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
// CHECK-NEXT:    %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
// CHECK-NEXT:    memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
// CHECK-NEXT:    scf.reduce
// CHECK-NEXT:  }
// CHECK-NEXT:  memref.dealloc %[[ALLOC]] : memref<2x3xf32>
// CHECK-NEXT:  return

// -----

// Two loops that each produce one reduction result (addf, then mulf): the
// expectations are one loop with two results, two init values and a single
// scf.reduce holding both reduction regions in order.
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %init2 = arith.constant 2.0 : f32
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    scf.reduce(%B_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2 : f32, f32
}

// CHECK-LABEL: func @fuse_reductions_two
// CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK:       %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME:      to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME:      init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
// CHECK:         %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
// CHECK:         %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
// CHECK:         scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
// CHECK:         ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK:           %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK:           scf.reduce.return %[[R]] : f32
// CHECK:         }
// CHECK:         ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK:           %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK:           scf.reduce.return %[[R]] : f32
// CHECK:         }
// CHECK:       return %[[RES]]#0, %[[RES]]#1 : f32, f32

// -----

// Three loops that each produce one reduction result (addf, mulf, addf): the
// expectations are one loop with three results and three reduction regions.
func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %init2 = arith.constant 2.0 : f32
  %init3 = arith.constant 3.0 : f32
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    scf.reduce(%B_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
    scf.reduce(%C_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2, %res3 : f32, f32, f32
}

// CHECK-LABEL: func @fuse_reductions_three
// CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG:   %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
// CHECK:       %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME:      to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME:      init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
// CHECK:         %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
// CHECK:         %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
// CHECK:         %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
// CHECK:         scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
// CHECK:         ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK:           %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK:           scf.reduce.return %[[R]] : f32
// CHECK:         }
// CHECK:         ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK:           %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK:           scf.reduce.return %[[R]] : f32
// CHECK:         }
// CHECK:         ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK:           %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK:           scf.reduce.return %[[R]] : f32
// CHECK:         }
// CHECK:       return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32

// -----

// The result of the first loop (%res1) is the init operand of the second
// loop.
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    scf.reduce(%B_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2 : f32, f32
}

// %res1 is used as second scf.parallel arg, cannot fuse
// CHECK-LABEL: func @reductions_use_res
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// The result of the first loop (%res1) is an operand of an arith.addf inside
// the body of the second loop.
func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %init2 = arith.constant 2.0 : f32
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum = arith.addf %B_elem, %res1 : f32
    scf.reduce(%sum : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2 : f32, f32
}

// %res1 is used inside second scf.parallel, cannot fuse
// CHECK-LABEL: func @reductions_use_res_inside
// CHECK:         scf.parallel
// CHECK:         scf.parallel

// -----

// An arith.addf placed between the two loops consumes the result of the first
// loop (%res1).
func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %init2 = arith.constant 2.0 : f32
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  %res3 = arith.addf %res1, %init2 : f32
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    scf.reduce(%B_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2, %res3 : f32, f32, f32
}

// instruction in between the loops uses the first loop result
// CHECK-LABEL: func @reductions_use_res_between
// CHECK:         scf.parallel
// CHECK:         scf.parallel