1 //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Affine/LoopUtils.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/Dialect/SCF/Utils/Utils.h" 15 #include "mlir/Transforms/Passes.h" 16 #include "mlir/Transforms/RegionUtils.h" 17 #include "llvm/Support/Debug.h" 18 19 #define PASS_NAME "loop-coalescing" 20 #define DEBUG_TYPE PASS_NAME 21 22 using namespace mlir; 23 24 namespace { 25 struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> { 26 27 /// Walk either an scf.for or an affine.for to find a band to coalesce. 28 template <typename LoopOpTy> 29 static void walkLoop(LoopOpTy op) { 30 // Ignore nested loops. 31 if (op->template getParentOfType<LoopOpTy>()) 32 return; 33 34 SmallVector<LoopOpTy, 4> loops; 35 getPerfectlyNestedLoops(loops, op); 36 LLVM_DEBUG(llvm::dbgs() 37 << "found a perfect nest of depth " << loops.size() << '\n'); 38 39 // Look for a band of loops that can be coalesced, i.e. perfectly nested 40 // loops with bounds defined above some loop. 41 // 1. For each loop, find above which parent loop its operands are 42 // defined. 43 SmallVector<unsigned, 4> operandsDefinedAbove(loops.size()); 44 for (unsigned i = 0, e = loops.size(); i < e; ++i) { 45 operandsDefinedAbove[i] = i; 46 for (unsigned j = 0; j < i; ++j) { 47 if (areValuesDefinedAbove(loops[i].getOperands(), 48 loops[j].getRegion())) { 49 operandsDefinedAbove[i] = j; 50 break; 51 } 52 } 53 LLVM_DEBUG(llvm::dbgs() 54 << " bounds of loop " << i << " are known above depth " 55 << operandsDefinedAbove[i] << '\n'); 56 } 57 58 // 2. Identify bands of loops such that the operands of all of them are 59 // defined above the first loop in the band. Traverse the nest bottom-up 60 // so that modifications don't invalidate the inner loops. 61 for (unsigned end = loops.size(); end > 0; --end) { 62 unsigned start = 0; 63 for (; start < end - 1; ++start) { 64 auto maxPos = 65 *std::max_element(std::next(operandsDefinedAbove.begin(), start), 66 std::next(operandsDefinedAbove.begin(), end)); 67 if (maxPos > start) 68 continue; 69 70 assert(maxPos == start && 71 "expected loop bounds to be known at the start of the band"); 72 LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start 73 << " to " << end << '\n'); 74 75 auto band = 76 llvm::makeMutableArrayRef(loops.data() + start, end - start); 77 (void)coalesceLoops(band); 78 break; 79 } 80 // If a band was found and transformed, keep looking at the loops above 81 // the outermost transformed loop. 82 if (start != end - 1) 83 end = start + 1; 84 } 85 } 86 87 void runOnOperation() override { 88 func::FuncOp func = getOperation(); 89 func.walk([&](Operation *op) { 90 if (auto scfForOp = dyn_cast<scf::ForOp>(op)) 91 walkLoop(scfForOp); 92 else if (auto affineForOp = dyn_cast<AffineForOp>(op)) 93 walkLoop(affineForOp); 94 }); 95 } 96 }; 97 98 } // namespace 99 100 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopCoalescingPass() { 101 return std::make_unique<LoopCoalescingPass>(); 102 } 103