xref: /llvm-project/mlir/test/Dialect/Affine/loop-fusion-inner.mlir (revision c79ffb02bbd79c9abe0add4a5bcef375af0a9755)
1// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' %s | FileCheck %s
2
3// Test fusion of affine nests inside other region-holding ops (affine.for,
4// scf.for, and scf.while in the test cases below).
5
6// CHECK-LABEL: func @fusion_inner_simple(
// Producer/consumer pair nested in an outer affine.for: the first inner nest
// copies %A into %B and the second copies %B into %C. The trailing "(" in the
// label above keeps it from also matching @fusion_inner_simple_scf, so the
// label identifies exactly one line of the output.
7func.func @fusion_inner_simple(%A : memref<10xf32>) {
8  %cst = arith.constant 0.0 : f32
9
10  affine.for %i = 0 to 100 {
11    %B = memref.alloc() : memref<10xf32>
12    %C = memref.alloc() : memref<10xf32>
13
14    affine.for %j = 0 to 10 {
15      %v = affine.load %A[%j] : memref<10xf32>
16      affine.store %v, %B[%j] : memref<10xf32>
17    }
18
19    affine.for %j = 0 to 10 {
20      %v = affine.load %B[%j] : memref<10xf32>
21      affine.store %v, %C[%j] : memref<10xf32>
22    }
23  }
24
  // Expected output: inside the outer loop the two allocs are immediately
  // followed by a single 0-to-10 nest, and no further affine.for appears
  // before the next label.
25  // CHECK:      affine.for %{{.*}} = 0 to 100
26  // CHECK-NEXT:   memref.alloc
27  // CHECK-NEXT:   memref.alloc
28  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10
29  // CHECK-NOT:    affine.for
30
31  return
32}
33
34// CHECK-LABEL: func @fusion_inner_simple_scf
// Same producer/consumer pair as @fusion_inner_simple (%A -> %B, then
// %B -> %C), but the enclosing loop is an scf.for instead of an affine.for.
// NOTE(review): %cst is defined but not referenced in this function.
35func.func @fusion_inner_simple_scf(%A : memref<10xf32>) {
36  %c0 = arith.constant 0 : index
37  %c1 = arith.constant 1 : index
38  %c100 = arith.constant 100 : index
39  %cst = arith.constant 0.0 : f32
40
41  scf.for %i = %c0 to %c100 step %c1 {
42    %B = memref.alloc() : memref<10xf32>
43    %C = memref.alloc() : memref<10xf32>
44
45    affine.for %j = 0 to 10 {
46      %v = affine.load %A[%j] : memref<10xf32>
47      affine.store %v, %B[%j] : memref<10xf32>
48    }
49
50    affine.for %j = 0 to 10 {
51      %v = affine.load %B[%j] : memref<10xf32>
52      affine.store %v, %C[%j] : memref<10xf32>
53    }
54  }
  // Expected output: the scf.for body starts with the two allocs, immediately
  // followed by a single 0-to-10 nest; no further affine.for appears before
  // the next label.
55  // CHECK:      scf.for
56  // CHECK-NEXT:   memref.alloc
57  // CHECK-NEXT:   memref.alloc
58  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10
59  // CHECK-NOT:    affine.for
60  return
61}
62
63// CHECK-LABEL: func @fusion_inner_multiple_nests
// Three sibling nests inside an outer affine.for, chained through temporary
// buffers: nest 1 writes %alloc_14, nest 2 reads %alloc_14 and writes
// %alloc_16, nest 3 reads %alloc_16. An alloc sits between nests 1 and 2 and a
// dealloc follows nest 3.
// NOTE(review): %alloc_15 is only loaded from; nothing in this function
// stores to it.
64func.func @fusion_inner_multiple_nests() {
65  %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x4xi8>
66  %alloc_10 = memref.alloc() : memref<8x4xi32>
67  affine.for %arg8 = 0 to 4 {
68    %alloc_14 = memref.alloc() : memref<4xi8>
69    %alloc_15 = memref.alloc() : memref<8x4xi8>
    // Nest 1: copy column %arg8 of %alloc_5 into %alloc_14.
70    affine.for %arg9 = 0 to 4 {
71      %0 = affine.load %alloc_5[%arg9, %arg8] : memref<4x4xi8>
72      affine.store %0, %alloc_14[%arg9] : memref<4xi8>
73    }
74    %alloc_16 = memref.alloc() : memref<4xi8>
    // Nest 2: copy %alloc_14 into %alloc_16.
75    affine.for %arg9 = 0 to 4 {
76      %0 = affine.load %alloc_14[%arg9] : memref<4xi8>
77      affine.store %0, %alloc_16[%arg9] : memref<4xi8>
78    }
    // Nest 3: reads only element 0 of %alloc_16; loads %alloc_10 at row
    // %arg9 * 4 and stores the result to row %arg9 * 4 + 3.
79    affine.for %arg9 = 0 to 2 {
80      %0 = affine.load %alloc_15[%arg9 * 4, 0] : memref<8x4xi8>
81      %1 = affine.load %alloc_16[0] : memref<4xi8>
82      %2 = affine.load %alloc_10[%arg9 * 4, %arg8] : memref<8x4xi32>
83      %3 = arith.muli %0, %1 : i8
84      %4 = arith.extsi %3 : i8 to i32
85      %5 = arith.addi %4, %2 : i32
86      affine.store %5, %alloc_10[%arg9 * 4 + 3, %arg8] : memref<8x4xi32>
87    }
88    memref.dealloc %alloc_16 : memref<4xi8>
89  }
90  // CHECK:      affine.for %{{.*}} = 0 to 4 {
91  // Everything inside fused into two nests (the second will be DCE'd).
92  // CHECK-NEXT:   memref.alloc() : memref<4xi8>
93  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
94  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
95  // CHECK-NEXT:   memref.alloc() : memref<8x4xi8>
96  // CHECK-NEXT:   memref.alloc() : memref<4xi8>
97  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2 {
98  // CHECK:        }
99  // CHECK:        affine.for %{{.*}} = 0 to 4 {
100  // CHECK:        }
101  // CHECK-NEXT:   memref.dealloc
102  // CHECK-NEXT: }
103  // CHECK-NEXT: return
104  return
105}
106
107// CHECK-LABEL: func @fusion_inside_scf_while
// Same producer/consumer pair (%A -> %B, then %B -> %C), this time placed in
// the "do" region of an scf.while.
// NOTE(review): %c0, %c1 and %c100 are defined but not referenced in this
// function.
108func.func @fusion_inside_scf_while(%A : memref<10xf32>) {
109  %c0 = arith.constant 0 : index
110  %c1 = arith.constant 1 : index
111  %c100 = arith.constant 100 : index
112  %cst = arith.constant 0.0 : f32
113
114  %0 = scf.while (%arg3 = %cst) : (f32) -> (f32) {
115    %1 = arith.cmpf ult, %arg3, %cst : f32
116    scf.condition(%1) %arg3 : f32
117  } do {
118  ^bb0(%arg5: f32):
119
120    %B = memref.alloc() : memref<10xf32>
121    %C = memref.alloc() : memref<10xf32>
122
123    affine.for %j = 0 to 10 {
124      %v = affine.load %A[%j] : memref<10xf32>
125      affine.store %v, %B[%j] : memref<10xf32>
126    }
127
128    affine.for %j = 0 to 10 {
129      %v = affine.load %B[%j] : memref<10xf32>
130      affine.store %v, %C[%j] : memref<10xf32>
131    }
132    %1 = arith.mulf %arg5, %cst : f32
133    scf.yield %1 : f32
134  }
  // Expected output: after scf.while there is a 0-to-10 nest, and no other
  // affine.for appears between it and the scf.yield.
135  // CHECK:      scf.while
136  // CHECK:        affine.for %{{.*}} = 0 to 10
137  // CHECK-NOT:    affine.for
138  // CHECK:        scf.yield
139  return
140}
141
142
// Private constant global filled with 0.0. @fusion_inner_long reads it via
// memref.get_global and uses it as the initial iter_args value of its inner
// scf.for.
143memref.global "private" constant @__constant_10x2xf32 : memref<10x2xf32> = dense<0.000000e+00>
144
145// CHECK-LABEL: func @fusion_inner_long
// Longer case with two levels of scf.for, each carrying a memref<10x2xf32>
// through iter_args. Affine nests appear inside the inner scf.for, inside the
// outer scf.for after the inner one, and after the outer scf.for.
// NOTE(review): %c0, %c9 and the function argument %arg1 are not referenced
// in the body; the inner scf.for's iter_arg (%arg6) is likewise never read,
// and %alloc is stored to but never loaded.
146func.func @fusion_inner_long(%arg0: memref<10x2xf32>, %arg1: memref<10x10xf32>, %arg2: memref<10x2xf32>, %s: index) {
147  %c0 = arith.constant 0 : index
148  %cst_0 = arith.constant 1.000000e-03 : f32
149  %c9 = arith.constant 9 : index
150  %c10_i32 = arith.constant 10 : i32
151  %c1_i32 = arith.constant 1 : i32
152  %c100_i32 = arith.constant 100 : i32
153  %c0_i32 = arith.constant 0 : i32
154  %0 = memref.get_global @__constant_10x2xf32 : memref<10x2xf32>
155  %1 = scf.for %arg3 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %arg0) -> (memref<10x2xf32>)  : i32 {
156    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
157    affine.for %arg5 = 0 to 10 {
158      %3 = arith.index_cast %arg5 : index to i32
159      affine.store %3, %alloc[%arg5] : memref<10xi32>
160    }
161    %2 = scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg6 = %0) -> (memref<10x2xf32>)  : i32 {
      // Chain of four nests linked through %alloc_5 -> %alloc_6 -> %alloc_7
      // -> %alloc_8; only %alloc_8 is yielded.
162      %alloc_5 = memref.alloc() : memref<2xf32>
163      affine.for %arg7 = 0 to 2 {
164        %16 = affine.load %arg4[%s, %arg7] : memref<10x2xf32>
165        affine.store %16, %alloc_5[%arg7] : memref<2xf32>
166      }
167      %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32>
168      affine.for %arg7 = 0 to 2 {
169        %16 = affine.load %alloc_5[%arg7] : memref<2xf32>
170        affine.store %16, %alloc_6[0, %arg7] : memref<1x2xf32>
171      }
172      %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
173      affine.for %arg7 = 0 to 10 {
174        affine.for %arg8 = 0 to 2 {
175          %16 = affine.load %alloc_6[0, %arg8] : memref<1x2xf32>
176          affine.store %16, %alloc_7[%arg7, %arg8] : memref<10x2xf32>
177        }
178      }
179      %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
180      affine.for %arg7 = 0 to 10 {
181        affine.for %arg8 = 0 to 2 {
182          %16 = affine.load %alloc_7[%arg7, %arg8] : memref<10x2xf32>
183          %17 = affine.load %arg4[%arg7, %arg8] : memref<10x2xf32>
184          %18 = arith.subf %16, %17 : f32
185          affine.store %18, %alloc_8[%arg7, %arg8] : memref<10x2xf32>
186        }
187      }
188      scf.yield %alloc_8 : memref<10x2xf32>
      // Expected output: within the inner scf.for, a 0-to-10 loop directly
      // wrapping a 0-to-2 loop, and no other affine.for before the yield.
189      // CHECK:      scf.for
190      // CHECK:        scf.for
191      // CHECK:          affine.for %{{.*}} = 0 to 10
192      // CHECK-NEXT:       affine.for %{{.*}} = 0 to 2
193      // CHECK-NOT:      affine.for
194      // CHECK:          scf.yield
195    }
    // Producer/consumer pair in the outer scf.for body: the first nest fills
    // %alloc_2 with %cst_0, the second multiplies it element-wise with the
    // inner loop's result %2 into %alloc_3.
196    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
197    affine.for %arg5 = 0 to 10 {
198      affine.for %arg6 = 0 to 2 {
199        affine.store %cst_0, %alloc_2[%arg5, %arg6] : memref<10x2xf32>
200      }
201    }
202    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
203    affine.for %arg5 = 0 to 10 {
204      affine.for %arg6 = 0 to 2 {
205        %3 = affine.load %alloc_2[%arg5, %arg6] : memref<10x2xf32>
206        %4 = affine.load %2[%arg5, %arg6] : memref<10x2xf32>
207        %5 = arith.mulf %3, %4 : f32
208        affine.store %5, %alloc_3[%arg5, %arg6] : memref<10x2xf32>
209      }
210    }
211    scf.yield %alloc_3 : memref<10x2xf32>
212    // The nests above will be fused as well.
213    // CHECK:      affine.for %{{.*}} = 0 to 10
214    // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2
215    // CHECK-NOT:  affine.for
216    // CHECK:      scf.yield
217  }
  // Final copy of the outer loop's result %1 into %arg2; there are no
  // expectations on this nest.
218  affine.for %arg3 = 0 to 10 {
219    affine.for %arg4 = 0 to 2 {
220      %2 = affine.load %1[%arg3, %arg4] : memref<10x2xf32>
221      affine.store %2, %arg2[%arg3, %arg4] : memref<10x2xf32>
222    }
223  }
224  return
225}
226