xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp (revision 608a663c8ee485c42637d021d554c8d264d556b1)
//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test various loop fusion utility functions.
//
//===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
13a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
14a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
18a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
19a70aa7bbSRiver Riddle 
20a70aa7bbSRiver Riddle #define DEBUG_TYPE "test-loop-fusion"
21a70aa7bbSRiver Riddle 
22a70aa7bbSRiver Riddle using namespace mlir;
234c48f016SMatthias Springer using namespace mlir::affine;
24a70aa7bbSRiver Riddle 
25a70aa7bbSRiver Riddle namespace {
26a70aa7bbSRiver Riddle 
27a70aa7bbSRiver Riddle struct TestLoopFusion
2858ceae95SRiver Riddle     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon472513560111::TestLoopFusion295e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
305e50dd04SRiver Riddle 
31a70aa7bbSRiver Riddle   StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon472513560111::TestLoopFusion32a70aa7bbSRiver Riddle   StringRef getDescription() const final {
33a70aa7bbSRiver Riddle     return "Tests loop fusion utility functions.";
34a70aa7bbSRiver Riddle   }
35a70aa7bbSRiver Riddle   void runOnOperation() override;
36*608a663cSPhilip Lassen 
37*608a663cSPhilip Lassen   TestLoopFusion() = default;
TestLoopFusion__anon472513560111::TestLoopFusion38*608a663cSPhilip Lassen   TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
39*608a663cSPhilip Lassen 
40*608a663cSPhilip Lassen   Option<bool> clTestDependenceCheck{
41*608a663cSPhilip Lassen       *this, "test-loop-fusion-dependence-check",
42*608a663cSPhilip Lassen       llvm::cl::desc("Enable testing of loop fusion dependence check"),
43*608a663cSPhilip Lassen       llvm::cl::init(false)};
44*608a663cSPhilip Lassen 
45*608a663cSPhilip Lassen   Option<bool> clTestSliceComputation{
46*608a663cSPhilip Lassen       *this, "test-loop-fusion-slice-computation",
47*608a663cSPhilip Lassen       llvm::cl::desc("Enable testing of loop fusion slice computation"),
48*608a663cSPhilip Lassen       llvm::cl::init(false)};
49*608a663cSPhilip Lassen 
50*608a663cSPhilip Lassen   Option<bool> clTestLoopFusionTransformation{
51*608a663cSPhilip Lassen       *this, "test-loop-fusion-transformation",
52*608a663cSPhilip Lassen       llvm::cl::desc("Enable testing of loop fusion transformation"),
53*608a663cSPhilip Lassen       llvm::cl::init(false)};
54a70aa7bbSRiver Riddle };
55a70aa7bbSRiver Riddle 
56a70aa7bbSRiver Riddle } // namespace
57a70aa7bbSRiver Riddle 
58a70aa7bbSRiver Riddle // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59a70aa7bbSRiver Riddle // in range ['loopDepth' + 1, 'maxLoopDepth'].
60a70aa7bbSRiver Riddle // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61a70aa7bbSRiver Riddle // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)62a70aa7bbSRiver Riddle static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63a70aa7bbSRiver Riddle                                 unsigned i, unsigned j, unsigned loopDepth,
64a70aa7bbSRiver Riddle                                 unsigned maxLoopDepth) {
654c48f016SMatthias Springer   affine::ComputationSliceState sliceUnion;
66a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67a70aa7bbSRiver Riddle     FusionResult result =
684c48f016SMatthias Springer         affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
69a70aa7bbSRiver Riddle     if (result.value == FusionResult::FailBlockDependence) {
70a70aa7bbSRiver Riddle       srcForOp->emitRemark("block-level dependence preventing"
71a70aa7bbSRiver Riddle                            " fusion of loop nest ")
72a70aa7bbSRiver Riddle           << i << " into loop nest " << j << " at depth " << loopDepth;
73a70aa7bbSRiver Riddle     }
74a70aa7bbSRiver Riddle   }
75a70aa7bbSRiver Riddle   return false;
76a70aa7bbSRiver Riddle }
77a70aa7bbSRiver Riddle 
78a70aa7bbSRiver Riddle // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)79a70aa7bbSRiver Riddle static unsigned getBlockIndex(Operation &op) {
80a70aa7bbSRiver Riddle   unsigned index = 0;
81a70aa7bbSRiver Riddle   for (auto &opX : *op.getBlock()) {
82a70aa7bbSRiver Riddle     if (&op == &opX)
83a70aa7bbSRiver Riddle       break;
84a70aa7bbSRiver Riddle     ++index;
85a70aa7bbSRiver Riddle   }
86a70aa7bbSRiver Riddle   return index;
87a70aa7bbSRiver Riddle }
88a70aa7bbSRiver Riddle 
89a70aa7bbSRiver Riddle // Returns a string representation of 'sliceUnion'.
904c48f016SMatthias Springer static std::string
getSliceStr(const affine::ComputationSliceState & sliceUnion)914c48f016SMatthias Springer getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92a70aa7bbSRiver Riddle   std::string result;
93a70aa7bbSRiver Riddle   llvm::raw_string_ostream os(result);
94a70aa7bbSRiver Riddle   // Slice insertion point format [loop-depth, operation-block-index]
95a70aa7bbSRiver Riddle   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
96a70aa7bbSRiver Riddle   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
97a70aa7bbSRiver Riddle   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
98a70aa7bbSRiver Riddle      << ")";
99a70aa7bbSRiver Riddle   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100a70aa7bbSRiver Riddle   os << " loop bounds: ";
101a70aa7bbSRiver Riddle   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102a70aa7bbSRiver Riddle     os << '[';
103a70aa7bbSRiver Riddle     sliceUnion.lbs[k].print(os);
104a70aa7bbSRiver Riddle     os << ", ";
105a70aa7bbSRiver Riddle     sliceUnion.ubs[k].print(os);
106a70aa7bbSRiver Riddle     os << "] ";
107a70aa7bbSRiver Riddle   }
108a70aa7bbSRiver Riddle   return os.str();
109a70aa7bbSRiver Riddle }
110a70aa7bbSRiver Riddle 
111a70aa7bbSRiver Riddle /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112a70aa7bbSRiver Riddle /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113a70aa7bbSRiver Riddle /// Emits a string representation of the slice union as a remark on 'loops[j]'
114a70aa7bbSRiver Riddle /// and marks this as incorrect slice if the slice is invalid. Returns false as
115a70aa7bbSRiver Riddle /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)116a70aa7bbSRiver Riddle static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117a70aa7bbSRiver Riddle                                  unsigned i, unsigned j, unsigned loopDepth,
118a70aa7bbSRiver Riddle                                  unsigned maxLoopDepth) {
119a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
1204c48f016SMatthias Springer     affine::ComputationSliceState sliceUnion;
1214c48f016SMatthias Springer     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
122a70aa7bbSRiver Riddle     if (result.value == FusionResult::Success) {
123a70aa7bbSRiver Riddle       forOpB->emitRemark("slice (")
124a70aa7bbSRiver Riddle           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125a70aa7bbSRiver Riddle           << " : " << getSliceStr(sliceUnion) << ")";
126a70aa7bbSRiver Riddle     } else if (result.value == FusionResult::FailIncorrectSlice) {
127a70aa7bbSRiver Riddle       forOpB->emitRemark("Incorrect slice (")
128a70aa7bbSRiver Riddle           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129a70aa7bbSRiver Riddle           << " : " << getSliceStr(sliceUnion) << ")";
130a70aa7bbSRiver Riddle     }
131a70aa7bbSRiver Riddle   }
132a70aa7bbSRiver Riddle   return false;
133a70aa7bbSRiver Riddle }
134a70aa7bbSRiver Riddle 
135a70aa7bbSRiver Riddle // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136a70aa7bbSRiver Riddle // ['loopDepth' + 1, 'maxLoopDepth'].
137a70aa7bbSRiver Riddle // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)138a70aa7bbSRiver Riddle static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139a70aa7bbSRiver Riddle                                          unsigned i, unsigned j,
140a70aa7bbSRiver Riddle                                          unsigned loopDepth,
141a70aa7bbSRiver Riddle                                          unsigned maxLoopDepth) {
142a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
1434c48f016SMatthias Springer     affine::ComputationSliceState sliceUnion;
1444c48f016SMatthias Springer     FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
145a70aa7bbSRiver Riddle     if (result.value == FusionResult::Success) {
1464c48f016SMatthias Springer       affine::fuseLoops(forOpA, forOpB, sliceUnion);
147a70aa7bbSRiver Riddle       // Note: 'forOpA' is removed to simplify test output. A proper loop
148a70aa7bbSRiver Riddle       // fusion pass should check the data dependence graph and run memref
149a70aa7bbSRiver Riddle       // region analysis to ensure removing 'forOpA' is safe.
150a70aa7bbSRiver Riddle       forOpA.erase();
151a70aa7bbSRiver Riddle       return true;
152a70aa7bbSRiver Riddle     }
153a70aa7bbSRiver Riddle   }
154a70aa7bbSRiver Riddle   return false;
155a70aa7bbSRiver Riddle }
156a70aa7bbSRiver Riddle 
157a70aa7bbSRiver Riddle using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158a70aa7bbSRiver Riddle                                    unsigned, unsigned)>;
159a70aa7bbSRiver Riddle 
160a70aa7bbSRiver Riddle // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161a70aa7bbSRiver Riddle // If 'return_on_change' is true, returns on first invocation of 'fn' which
162a70aa7bbSRiver Riddle // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)163a70aa7bbSRiver Riddle static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164a70aa7bbSRiver Riddle                          LoopFunc fn, bool returnOnChange = false) {
165a70aa7bbSRiver Riddle   bool changed = false;
166a70aa7bbSRiver Riddle   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167a70aa7bbSRiver Riddle        ++loopDepth) {
168a70aa7bbSRiver Riddle     auto &loops = depthToLoops[loopDepth];
169a70aa7bbSRiver Riddle     unsigned numLoops = loops.size();
170a70aa7bbSRiver Riddle     for (unsigned j = 0; j < numLoops; ++j) {
171a70aa7bbSRiver Riddle       for (unsigned k = 0; k < numLoops; ++k) {
172a70aa7bbSRiver Riddle         if (j != k)
173a70aa7bbSRiver Riddle           changed |=
174a70aa7bbSRiver Riddle               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175a70aa7bbSRiver Riddle         if (changed && returnOnChange)
176a70aa7bbSRiver Riddle           return true;
177a70aa7bbSRiver Riddle       }
178a70aa7bbSRiver Riddle     }
179a70aa7bbSRiver Riddle   }
180a70aa7bbSRiver Riddle   return changed;
181a70aa7bbSRiver Riddle }
182a70aa7bbSRiver Riddle 
runOnOperation()183a70aa7bbSRiver Riddle void TestLoopFusion::runOnOperation() {
184a70aa7bbSRiver Riddle   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185a70aa7bbSRiver Riddle   if (clTestLoopFusionTransformation) {
186a70aa7bbSRiver Riddle     // Run loop fusion until a fixed point is reached.
187a70aa7bbSRiver Riddle     do {
188a70aa7bbSRiver Riddle       depthToLoops.clear();
189a70aa7bbSRiver Riddle       // Gather all AffineForOps by loop depth.
190a70aa7bbSRiver Riddle       gatherLoops(getOperation(), depthToLoops);
191a70aa7bbSRiver Riddle 
192a70aa7bbSRiver Riddle       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193a70aa7bbSRiver Riddle     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194a70aa7bbSRiver Riddle                           /*returnOnChange=*/true));
195a70aa7bbSRiver Riddle     return;
196a70aa7bbSRiver Riddle   }
197a70aa7bbSRiver Riddle 
198a70aa7bbSRiver Riddle   // Gather all AffineForOps by loop depth.
199a70aa7bbSRiver Riddle   gatherLoops(getOperation(), depthToLoops);
200a70aa7bbSRiver Riddle 
201a70aa7bbSRiver Riddle   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202a70aa7bbSRiver Riddle   if (clTestDependenceCheck)
203a70aa7bbSRiver Riddle     iterateLoops(depthToLoops, testDependenceCheck);
204a70aa7bbSRiver Riddle   if (clTestSliceComputation)
205a70aa7bbSRiver Riddle     iterateLoops(depthToLoops, testSliceComputation);
206a70aa7bbSRiver Riddle }
207a70aa7bbSRiver Riddle 
208a70aa7bbSRiver Riddle namespace mlir {
209a70aa7bbSRiver Riddle namespace test {
registerTestLoopFusion()210a70aa7bbSRiver Riddle void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211a70aa7bbSRiver Riddle } // namespace test
212a70aa7bbSRiver Riddle } // namespace mlir
213