xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Affine/LoopUtils.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SCF/IR/SCF.h"
14 #include "mlir/Dialect/SCF/Utils/Utils.h"
15 #include "mlir/Transforms/Passes.h"
16 #include "mlir/Transforms/RegionUtils.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define PASS_NAME "loop-coalescing"
20 #define DEBUG_TYPE PASS_NAME
21 
22 using namespace mlir;
23 
24 namespace {
25 struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> {
26 
27   /// Walk either an scf.for or an affine.for to find a band to coalesce.
28   template <typename LoopOpTy>
29   static void walkLoop(LoopOpTy op) {
30     // Ignore nested loops.
31     if (op->template getParentOfType<LoopOpTy>())
32       return;
33 
34     SmallVector<LoopOpTy, 4> loops;
35     getPerfectlyNestedLoops(loops, op);
36     LLVM_DEBUG(llvm::dbgs()
37                << "found a perfect nest of depth " << loops.size() << '\n');
38 
39     // Look for a band of loops that can be coalesced, i.e. perfectly nested
40     // loops with bounds defined above some loop.
41     // 1. For each loop, find above which parent loop its operands are
42     // defined.
43     SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
44     for (unsigned i = 0, e = loops.size(); i < e; ++i) {
45       operandsDefinedAbove[i] = i;
46       for (unsigned j = 0; j < i; ++j) {
47         if (areValuesDefinedAbove(loops[i].getOperands(),
48                                   loops[j].getRegion())) {
49           operandsDefinedAbove[i] = j;
50           break;
51         }
52       }
53       LLVM_DEBUG(llvm::dbgs()
54                  << "  bounds of loop " << i << " are known above depth "
55                  << operandsDefinedAbove[i] << '\n');
56     }
57 
58     // 2. Identify bands of loops such that the operands of all of them are
59     // defined above the first loop in the band.  Traverse the nest bottom-up
60     // so that modifications don't invalidate the inner loops.
61     for (unsigned end = loops.size(); end > 0; --end) {
62       unsigned start = 0;
63       for (; start < end - 1; ++start) {
64         auto maxPos =
65             *std::max_element(std::next(operandsDefinedAbove.begin(), start),
66                               std::next(operandsDefinedAbove.begin(), end));
67         if (maxPos > start)
68           continue;
69 
70         assert(maxPos == start &&
71                "expected loop bounds to be known at the start of the band");
72         LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
73                                 << " to " << end << '\n');
74 
75         auto band =
76             llvm::makeMutableArrayRef(loops.data() + start, end - start);
77         (void)coalesceLoops(band);
78         break;
79       }
80       // If a band was found and transformed, keep looking at the loops above
81       // the outermost transformed loop.
82       if (start != end - 1)
83         end = start + 1;
84     }
85   }
86 
87   void runOnOperation() override {
88     func::FuncOp func = getOperation();
89     func.walk([&](Operation *op) {
90       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
91         walkLoop(scfForOp);
92       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
93         walkLoop(affineForOp);
94     });
95   }
96 };
97 
98 } // namespace
99 
100 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopCoalescingPass() {
101   return std::make_unique<LoopCoalescingPass>();
102 }
103