xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #define DEBUG_TYPE "test-loop-fusion"
21 
22 using namespace mlir;
23 using namespace mlir::affine;
24 
25 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
26 
27 static llvm::cl::opt<bool> clTestDependenceCheck(
28     "test-loop-fusion-dependence-check",
29     llvm::cl::desc("Enable testing of loop fusion dependence check"),
30     llvm::cl::cat(clOptionsCategory));
31 
32 static llvm::cl::opt<bool> clTestSliceComputation(
33     "test-loop-fusion-slice-computation",
34     llvm::cl::desc("Enable testing of loop fusion slice computation"),
35     llvm::cl::cat(clOptionsCategory));
36 
37 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
38     "test-loop-fusion-transformation",
39     llvm::cl::desc("Enable testing of loop fusion transformation"),
40     llvm::cl::cat(clOptionsCategory));
41 
42 namespace {
43 
44 struct TestLoopFusion
45     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
46   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
47 
48   StringRef getArgument() const final { return "test-loop-fusion"; }
49   StringRef getDescription() const final {
50     return "Tests loop fusion utility functions.";
51   }
52   void runOnOperation() override;
53 };
54 
55 } // namespace
56 
57 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
58 // in range ['loopDepth' + 1, 'maxLoopDepth'].
59 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
60 // Returns false as IR is not transformed.
61 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
62                                 unsigned i, unsigned j, unsigned loopDepth,
63                                 unsigned maxLoopDepth) {
64   affine::ComputationSliceState sliceUnion;
65   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
66     FusionResult result =
67         affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
68     if (result.value == FusionResult::FailBlockDependence) {
69       srcForOp->emitRemark("block-level dependence preventing"
70                            " fusion of loop nest ")
71           << i << " into loop nest " << j << " at depth " << loopDepth;
72     }
73   }
74   return false;
75 }
76 
77 // Returns the index of 'op' in its block.
78 static unsigned getBlockIndex(Operation &op) {
79   unsigned index = 0;
80   for (auto &opX : *op.getBlock()) {
81     if (&op == &opX)
82       break;
83     ++index;
84   }
85   return index;
86 }
87 
88 // Returns a string representation of 'sliceUnion'.
89 static std::string
90 getSliceStr(const affine::ComputationSliceState &sliceUnion) {
91   std::string result;
92   llvm::raw_string_ostream os(result);
93   // Slice insertion point format [loop-depth, operation-block-index]
94   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
95   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
96   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
97      << ")";
98   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
99   os << " loop bounds: ";
100   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
101     os << '[';
102     sliceUnion.lbs[k].print(os);
103     os << ", ";
104     sliceUnion.ubs[k].print(os);
105     os << "] ";
106   }
107   return os.str();
108 }
109 
110 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
111 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
112 /// Emits a string representation of the slice union as a remark on 'loops[j]'
113 /// and marks this as incorrect slice if the slice is invalid. Returns false as
114 /// IR is not transformed.
115 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
116                                  unsigned i, unsigned j, unsigned loopDepth,
117                                  unsigned maxLoopDepth) {
118   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
119     affine::ComputationSliceState sliceUnion;
120     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
121     if (result.value == FusionResult::Success) {
122       forOpB->emitRemark("slice (")
123           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
124           << " : " << getSliceStr(sliceUnion) << ")";
125     } else if (result.value == FusionResult::FailIncorrectSlice) {
126       forOpB->emitRemark("Incorrect slice (")
127           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
128           << " : " << getSliceStr(sliceUnion) << ")";
129     }
130   }
131   return false;
132 }
133 
134 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
135 // ['loopDepth' + 1, 'maxLoopDepth'].
136 // Returns true if loops were successfully fused, false otherwise.
137 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
138                                          unsigned i, unsigned j,
139                                          unsigned loopDepth,
140                                          unsigned maxLoopDepth) {
141   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
142     affine::ComputationSliceState sliceUnion;
143     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
144     if (result.value == FusionResult::Success) {
145       affine::fuseLoops(forOpA, forOpB, sliceUnion);
146       // Note: 'forOpA' is removed to simplify test output. A proper loop
147       // fusion pass should check the data dependence graph and run memref
148       // region analysis to ensure removing 'forOpA' is safe.
149       forOpA.erase();
150       return true;
151     }
152   }
153   return false;
154 }
155 
156 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
157                                    unsigned, unsigned)>;
158 
159 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
160 // If 'return_on_change' is true, returns on first invocation of 'fn' which
161 // returns true.
162 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
163                          LoopFunc fn, bool returnOnChange = false) {
164   bool changed = false;
165   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
166        ++loopDepth) {
167     auto &loops = depthToLoops[loopDepth];
168     unsigned numLoops = loops.size();
169     for (unsigned j = 0; j < numLoops; ++j) {
170       for (unsigned k = 0; k < numLoops; ++k) {
171         if (j != k)
172           changed |=
173               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
174         if (changed && returnOnChange)
175           return true;
176       }
177     }
178   }
179   return changed;
180 }
181 
182 void TestLoopFusion::runOnOperation() {
183   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
184   if (clTestLoopFusionTransformation) {
185     // Run loop fusion until a fixed point is reached.
186     do {
187       depthToLoops.clear();
188       // Gather all AffineForOps by loop depth.
189       gatherLoops(getOperation(), depthToLoops);
190 
191       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
192     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
193                           /*returnOnChange=*/true));
194     return;
195   }
196 
197   // Gather all AffineForOps by loop depth.
198   gatherLoops(getOperation(), depthToLoops);
199 
200   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
201   if (clTestDependenceCheck)
202     iterateLoops(depthToLoops, testDependenceCheck);
203   if (clTestSliceComputation)
204     iterateLoops(depthToLoops, testSliceComputation);
205 }
206 
207 namespace mlir {
208 namespace test {
209 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
210 } // namespace test
211 } // namespace mlir
212