xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp (revision 608a663c8ee485c42637d021d554c8d264d556b1)
1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #define DEBUG_TYPE "test-loop-fusion"
21 
22 using namespace mlir;
23 using namespace mlir::affine;
24 
25 namespace {
26 
27 struct TestLoopFusion
28     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon472513560111::TestLoopFusion29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
30 
31   StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon472513560111::TestLoopFusion32   StringRef getDescription() const final {
33     return "Tests loop fusion utility functions.";
34   }
35   void runOnOperation() override;
36 
37   TestLoopFusion() = default;
TestLoopFusion__anon472513560111::TestLoopFusion38   TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
39 
40   Option<bool> clTestDependenceCheck{
41       *this, "test-loop-fusion-dependence-check",
42       llvm::cl::desc("Enable testing of loop fusion dependence check"),
43       llvm::cl::init(false)};
44 
45   Option<bool> clTestSliceComputation{
46       *this, "test-loop-fusion-slice-computation",
47       llvm::cl::desc("Enable testing of loop fusion slice computation"),
48       llvm::cl::init(false)};
49 
50   Option<bool> clTestLoopFusionTransformation{
51       *this, "test-loop-fusion-transformation",
52       llvm::cl::desc("Enable testing of loop fusion transformation"),
53       llvm::cl::init(false)};
54 };
55 
56 } // namespace
57 
58 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59 // in range ['loopDepth' + 1, 'maxLoopDepth'].
60 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61 // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)62 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63                                 unsigned i, unsigned j, unsigned loopDepth,
64                                 unsigned maxLoopDepth) {
65   affine::ComputationSliceState sliceUnion;
66   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67     FusionResult result =
68         affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
69     if (result.value == FusionResult::FailBlockDependence) {
70       srcForOp->emitRemark("block-level dependence preventing"
71                            " fusion of loop nest ")
72           << i << " into loop nest " << j << " at depth " << loopDepth;
73     }
74   }
75   return false;
76 }
77 
78 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)79 static unsigned getBlockIndex(Operation &op) {
80   unsigned index = 0;
81   for (auto &opX : *op.getBlock()) {
82     if (&op == &opX)
83       break;
84     ++index;
85   }
86   return index;
87 }
88 
89 // Returns a string representation of 'sliceUnion'.
90 static std::string
getSliceStr(const affine::ComputationSliceState & sliceUnion)91 getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92   std::string result;
93   llvm::raw_string_ostream os(result);
94   // Slice insertion point format [loop-depth, operation-block-index]
95   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
96   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
97   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
98      << ")";
99   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100   os << " loop bounds: ";
101   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102     os << '[';
103     sliceUnion.lbs[k].print(os);
104     os << ", ";
105     sliceUnion.ubs[k].print(os);
106     os << "] ";
107   }
108   return os.str();
109 }
110 
111 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113 /// Emits a string representation of the slice union as a remark on 'loops[j]'
114 /// and marks this as incorrect slice if the slice is invalid. Returns false as
115 /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)116 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117                                  unsigned i, unsigned j, unsigned loopDepth,
118                                  unsigned maxLoopDepth) {
119   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
120     affine::ComputationSliceState sliceUnion;
121     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
122     if (result.value == FusionResult::Success) {
123       forOpB->emitRemark("slice (")
124           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125           << " : " << getSliceStr(sliceUnion) << ")";
126     } else if (result.value == FusionResult::FailIncorrectSlice) {
127       forOpB->emitRemark("Incorrect slice (")
128           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129           << " : " << getSliceStr(sliceUnion) << ")";
130     }
131   }
132   return false;
133 }
134 
135 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136 // ['loopDepth' + 1, 'maxLoopDepth'].
137 // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)138 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139                                          unsigned i, unsigned j,
140                                          unsigned loopDepth,
141                                          unsigned maxLoopDepth) {
142   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
143     affine::ComputationSliceState sliceUnion;
144     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
145     if (result.value == FusionResult::Success) {
146       affine::fuseLoops(forOpA, forOpB, sliceUnion);
147       // Note: 'forOpA' is removed to simplify test output. A proper loop
148       // fusion pass should check the data dependence graph and run memref
149       // region analysis to ensure removing 'forOpA' is safe.
150       forOpA.erase();
151       return true;
152     }
153   }
154   return false;
155 }
156 
157 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158                                    unsigned, unsigned)>;
159 
160 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161 // If 'return_on_change' is true, returns on first invocation of 'fn' which
162 // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)163 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164                          LoopFunc fn, bool returnOnChange = false) {
165   bool changed = false;
166   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167        ++loopDepth) {
168     auto &loops = depthToLoops[loopDepth];
169     unsigned numLoops = loops.size();
170     for (unsigned j = 0; j < numLoops; ++j) {
171       for (unsigned k = 0; k < numLoops; ++k) {
172         if (j != k)
173           changed |=
174               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175         if (changed && returnOnChange)
176           return true;
177       }
178     }
179   }
180   return changed;
181 }
182 
runOnOperation()183 void TestLoopFusion::runOnOperation() {
184   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185   if (clTestLoopFusionTransformation) {
186     // Run loop fusion until a fixed point is reached.
187     do {
188       depthToLoops.clear();
189       // Gather all AffineForOps by loop depth.
190       gatherLoops(getOperation(), depthToLoops);
191 
192       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194                           /*returnOnChange=*/true));
195     return;
196   }
197 
198   // Gather all AffineForOps by loop depth.
199   gatherLoops(getOperation(), depthToLoops);
200 
201   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202   if (clTestDependenceCheck)
203     iterateLoops(depthToLoops, testDependenceCheck);
204   if (clTestSliceComputation)
205     iterateLoops(depthToLoops, testSliceComputation);
206 }
207 
208 namespace mlir {
209 namespace test {
registerTestLoopFusion()210 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211 } // namespace test
212 } // namespace mlir
213