// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' %s | FileCheck %s

// Test fusion of affine nests inside other region-holding ops (scf.for and
// scf.while in the test cases below).

// CHECK-LABEL: func @fusion_inner_simple
func.func @fusion_inner_simple(%A : memref<10xf32>) {
  %cst = arith.constant 0.0 : f32

  affine.for %i = 0 to 100 {
    %B = memref.alloc() : memref<10xf32>
    %C = memref.alloc() : memref<10xf32>

    affine.for %j = 0 to 10 {
      %v = affine.load %A[%j] : memref<10xf32>
      affine.store %v, %B[%j] : memref<10xf32>
    }

    affine.for %j = 0 to 10 {
      %v = affine.load %B[%j] : memref<10xf32>
      affine.store %v, %C[%j] : memref<10xf32>
    }
  }

  // CHECK: affine.for %{{.*}} = 0 to 100
  // CHECK-NEXT: memref.alloc
  // CHECK-NEXT: memref.alloc
  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10
  // CHECK-NOT: affine.for

  return
}

// CHECK-LABEL: func @fusion_inner_simple_scf
func.func @fusion_inner_simple_scf(%A : memref<10xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c100 = arith.constant 100 : index
  %cst = arith.constant 0.0 : f32

  scf.for %i = %c0 to %c100 step %c1 {
    %B = memref.alloc() : memref<10xf32>
    %C = memref.alloc() : memref<10xf32>

    affine.for %j = 0 to 10 {
      %v = affine.load %A[%j] : memref<10xf32>
      affine.store %v, %B[%j] : memref<10xf32>
    }

    affine.for %j = 0 to 10 {
      %v = affine.load %B[%j] : memref<10xf32>
      affine.store %v, %C[%j] : memref<10xf32>
    }
  }
  // CHECK: scf.for
  // CHECK-NEXT: memref.alloc
  // CHECK-NEXT: memref.alloc
  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10
  // CHECK-NOT: affine.for
  return
}

// CHECK-LABEL: func @fusion_inner_multiple_nests
func.func @fusion_inner_multiple_nests() {
  %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x4xi8>
  %alloc_10 = memref.alloc() : memref<8x4xi32>
  affine.for %arg8 = 0 to 4 {
    %alloc_14 = memref.alloc() : memref<4xi8>
    %alloc_15 = memref.alloc() : memref<8x4xi8>
    affine.for %arg9 = 0 to 4 {
      %0 = affine.load %alloc_5[%arg9, %arg8] : memref<4x4xi8>
      affine.store %0, %alloc_14[%arg9] : memref<4xi8>
    }
    %alloc_16 = memref.alloc() : memref<4xi8>
    affine.for %arg9 = 0 to 4 {
      %0 = affine.load %alloc_14[%arg9] : memref<4xi8>
      affine.store %0, %alloc_16[%arg9] : memref<4xi8>
    }
    affine.for %arg9 = 0 to 2 {
      %0 = affine.load %alloc_15[%arg9 * 4, 0] : memref<8x4xi8>
      %1 = affine.load %alloc_16[0] : memref<4xi8>
      %2 = affine.load %alloc_10[%arg9 * 4, %arg8] : memref<8x4xi32>
      %3 = arith.muli %0, %1 : i8
      %4 = arith.extsi %3 : i8 to i32
      %5 = arith.addi %4, %2 : i32
      affine.store %5, %alloc_10[%arg9 * 4 + 3, %arg8] : memref<8x4xi32>
    }
    memref.dealloc %alloc_16 : memref<4xi8>
  }
  // CHECK: affine.for %{{.*}} = 0 to 4 {
  // Everything inside is fused into two nests (the second will be DCE'd).
  // CHECK-NEXT: memref.alloc() : memref<4xi8>
  // CHECK-NEXT: memref.alloc() : memref<1xi8>
  // CHECK-NEXT: memref.alloc() : memref<1xi8>
  // CHECK-NEXT: memref.alloc() : memref<8x4xi8>
  // CHECK-NEXT: memref.alloc() : memref<4xi8>
  // CHECK-NEXT: affine.for %{{.*}} = 0 to 2 {
  // CHECK: }
  // CHECK: affine.for %{{.*}} = 0 to 4 {
  // CHECK: }
  // CHECK-NEXT: memref.dealloc
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}

// CHECK-LABEL: func @fusion_inside_scf_while
func.func @fusion_inside_scf_while(%A : memref<10xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c100 = arith.constant 100 : index
  %cst = arith.constant 0.0 : f32

  %0 = scf.while (%arg3 = %cst) : (f32) -> (f32) {
    %1 = arith.cmpf ult, %arg3, %cst : f32
    scf.condition(%1) %arg3 : f32
  } do {
  ^bb0(%arg5: f32):

    %B = memref.alloc() : memref<10xf32>
    %C = memref.alloc() : memref<10xf32>

    affine.for %j = 0 to 10 {
      %v = affine.load %A[%j] : memref<10xf32>
      affine.store %v, %B[%j] : memref<10xf32>
    }

    affine.for %j = 0 to 10 {
      %v = affine.load %B[%j] : memref<10xf32>
      affine.store %v, %C[%j] : memref<10xf32>
    }
    %1 = arith.mulf %arg5, %cst : f32
    scf.yield %1 : f32
  }
  // CHECK: scf.while
  // CHECK: affine.for %{{.*}} = 0 to 10
  // CHECK-NOT: affine.for
  // CHECK: scf.yield
  return
}


memref.global "private" constant @__constant_10x2xf32 : memref<10x2xf32> = dense<0.000000e+00>

// CHECK-LABEL: func @fusion_inner_long
func.func @fusion_inner_long(%arg0: memref<10x2xf32>, %arg1: memref<10x10xf32>, %arg2: memref<10x2xf32>, %s: index) {
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 1.000000e-03 : f32
  %c9 = arith.constant 9 : index
  %c10_i32 = arith.constant 10 : i32
  %c1_i32 = arith.constant 1 : i32
  %c100_i32 = arith.constant 100 : i32
  %c0_i32 = arith.constant 0 : i32
  %0 = memref.get_global @__constant_10x2xf32 : memref<10x2xf32>
  %1 = scf.for %arg3 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %arg0) -> (memref<10x2xf32>) : i32 {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    affine.for %arg5 = 0 to 10 {
      %3 = arith.index_cast %arg5 : index to i32
      affine.store %3, %alloc[%arg5] : memref<10xi32>
    }
    %2 = scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg6 = %0) -> (memref<10x2xf32>) : i32 {
      %alloc_5 = memref.alloc() : memref<2xf32>
      affine.for %arg7 = 0 to 2 {
        %16 = affine.load %arg4[%s, %arg7] : memref<10x2xf32>
        affine.store %16, %alloc_5[%arg7] : memref<2xf32>
      }
      %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32>
      affine.for %arg7 = 0 to 2 {
        %16 = affine.load %alloc_5[%arg7] : memref<2xf32>
        affine.store %16, %alloc_6[0, %arg7] : memref<1x2xf32>
      }
      %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
      affine.for %arg7 = 0 to 10 {
        affine.for %arg8 = 0 to 2 {
          %16 = affine.load %alloc_6[0, %arg8] : memref<1x2xf32>
          affine.store %16, %alloc_7[%arg7, %arg8] : memref<10x2xf32>
        }
      }
      %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
      affine.for %arg7 = 0 to 10 {
        affine.for %arg8 = 0 to 2 {
          %16 = affine.load %alloc_7[%arg7, %arg8] : memref<10x2xf32>
          %17 = affine.load %arg4[%arg7, %arg8] : memref<10x2xf32>
          %18 = arith.subf %16, %17 : f32
          affine.store %18, %alloc_8[%arg7, %arg8] : memref<10x2xf32>
        }
      }
      scf.yield %alloc_8 : memref<10x2xf32>
      // CHECK: scf.for
      // CHECK: scf.for
      // CHECK: affine.for %{{.*}} = 0 to 10
      // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
      // CHECK-NOT: affine.for
      // CHECK: scf.yield
    }
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
    affine.for %arg5 = 0 to 10 {
      affine.for %arg6 = 0 to 2 {
        affine.store %cst_0, %alloc_2[%arg5, %arg6] : memref<10x2xf32>
      }
    }
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
    affine.for %arg5 = 0 to 10 {
      affine.for %arg6 = 0 to 2 {
        %3 = affine.load %alloc_2[%arg5, %arg6] : memref<10x2xf32>
        %4 = affine.load %2[%arg5, %arg6] : memref<10x2xf32>
        %5 = arith.mulf %3, %4 : f32
        affine.store %5, %alloc_3[%arg5, %arg6] : memref<10x2xf32>
      }
    }
    scf.yield %alloc_3 : memref<10x2xf32>
    // The nests above will be fused as well.
    // CHECK: affine.for %{{.*}} = 0 to 10
    // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
    // CHECK-NOT: affine.for
    // CHECK: scf.yield
  }
  affine.for %arg3 = 0 to 10 {
    affine.for %arg4 = 0 to 2 {
      %2 = affine.load %1[%arg3, %arg4] : memref<10x2xf32>
      affine.store %2, %arg2[%arg3, %arg4] : memref<10x2xf32>
    }
  }
  return
}