1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19
20 #define DEBUG_TYPE "test-loop-fusion"
21
22 using namespace mlir;
23 using namespace mlir::affine;
24
25 namespace {
26
27 struct TestLoopFusion
28 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon472513560111::TestLoopFusion29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
30
31 StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon472513560111::TestLoopFusion32 StringRef getDescription() const final {
33 return "Tests loop fusion utility functions.";
34 }
35 void runOnOperation() override;
36
37 TestLoopFusion() = default;
TestLoopFusion__anon472513560111::TestLoopFusion38 TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
39
40 Option<bool> clTestDependenceCheck{
41 *this, "test-loop-fusion-dependence-check",
42 llvm::cl::desc("Enable testing of loop fusion dependence check"),
43 llvm::cl::init(false)};
44
45 Option<bool> clTestSliceComputation{
46 *this, "test-loop-fusion-slice-computation",
47 llvm::cl::desc("Enable testing of loop fusion slice computation"),
48 llvm::cl::init(false)};
49
50 Option<bool> clTestLoopFusionTransformation{
51 *this, "test-loop-fusion-transformation",
52 llvm::cl::desc("Enable testing of loop fusion transformation"),
53 llvm::cl::init(false)};
54 };
55
56 } // namespace
57
58 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59 // in range ['loopDepth' + 1, 'maxLoopDepth'].
60 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61 // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)62 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63 unsigned i, unsigned j, unsigned loopDepth,
64 unsigned maxLoopDepth) {
65 affine::ComputationSliceState sliceUnion;
66 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67 FusionResult result =
68 affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
69 if (result.value == FusionResult::FailBlockDependence) {
70 srcForOp->emitRemark("block-level dependence preventing"
71 " fusion of loop nest ")
72 << i << " into loop nest " << j << " at depth " << loopDepth;
73 }
74 }
75 return false;
76 }
77
78 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)79 static unsigned getBlockIndex(Operation &op) {
80 unsigned index = 0;
81 for (auto &opX : *op.getBlock()) {
82 if (&op == &opX)
83 break;
84 ++index;
85 }
86 return index;
87 }
88
89 // Returns a string representation of 'sliceUnion'.
90 static std::string
getSliceStr(const affine::ComputationSliceState & sliceUnion)91 getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92 std::string result;
93 llvm::raw_string_ostream os(result);
94 // Slice insertion point format [loop-depth, operation-block-index]
95 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
96 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
97 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
98 << ")";
99 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100 os << " loop bounds: ";
101 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102 os << '[';
103 sliceUnion.lbs[k].print(os);
104 os << ", ";
105 sliceUnion.ubs[k].print(os);
106 os << "] ";
107 }
108 return os.str();
109 }
110
111 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113 /// Emits a string representation of the slice union as a remark on 'loops[j]'
114 /// and marks this as incorrect slice if the slice is invalid. Returns false as
115 /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)116 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117 unsigned i, unsigned j, unsigned loopDepth,
118 unsigned maxLoopDepth) {
119 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
120 affine::ComputationSliceState sliceUnion;
121 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
122 if (result.value == FusionResult::Success) {
123 forOpB->emitRemark("slice (")
124 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125 << " : " << getSliceStr(sliceUnion) << ")";
126 } else if (result.value == FusionResult::FailIncorrectSlice) {
127 forOpB->emitRemark("Incorrect slice (")
128 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129 << " : " << getSliceStr(sliceUnion) << ")";
130 }
131 }
132 return false;
133 }
134
135 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136 // ['loopDepth' + 1, 'maxLoopDepth'].
137 // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)138 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139 unsigned i, unsigned j,
140 unsigned loopDepth,
141 unsigned maxLoopDepth) {
142 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
143 affine::ComputationSliceState sliceUnion;
144 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
145 if (result.value == FusionResult::Success) {
146 affine::fuseLoops(forOpA, forOpB, sliceUnion);
147 // Note: 'forOpA' is removed to simplify test output. A proper loop
148 // fusion pass should check the data dependence graph and run memref
149 // region analysis to ensure removing 'forOpA' is safe.
150 forOpA.erase();
151 return true;
152 }
153 }
154 return false;
155 }
156
157 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158 unsigned, unsigned)>;
159
160 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161 // If 'return_on_change' is true, returns on first invocation of 'fn' which
162 // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)163 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164 LoopFunc fn, bool returnOnChange = false) {
165 bool changed = false;
166 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167 ++loopDepth) {
168 auto &loops = depthToLoops[loopDepth];
169 unsigned numLoops = loops.size();
170 for (unsigned j = 0; j < numLoops; ++j) {
171 for (unsigned k = 0; k < numLoops; ++k) {
172 if (j != k)
173 changed |=
174 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175 if (changed && returnOnChange)
176 return true;
177 }
178 }
179 }
180 return changed;
181 }
182
runOnOperation()183 void TestLoopFusion::runOnOperation() {
184 std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185 if (clTestLoopFusionTransformation) {
186 // Run loop fusion until a fixed point is reached.
187 do {
188 depthToLoops.clear();
189 // Gather all AffineForOps by loop depth.
190 gatherLoops(getOperation(), depthToLoops);
191
192 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193 } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194 /*returnOnChange=*/true));
195 return;
196 }
197
198 // Gather all AffineForOps by loop depth.
199 gatherLoops(getOperation(), depthToLoops);
200
201 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202 if (clTestDependenceCheck)
203 iterateLoops(depthToLoops, testDependenceCheck);
204 if (clTestSliceComputation)
205 iterateLoops(depthToLoops, testSliceComputation);
206 }
207
208 namespace mlir {
209 namespace test {
registerTestLoopFusion()210 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211 } // namespace test
212 } // namespace mlir
213