1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 25 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::scf; 30 31 /// Verify there are no nested ParallelOps. 32 static bool hasNestedParallelOp(ParallelOp ploop) { 33 auto walkResult = 34 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 35 return walkResult.wasInterrupted(); 36 } 37 38 /// Verify equal iteration spaces. 39 static bool equalIterationSpaces(ParallelOp firstPloop, 40 ParallelOp secondPloop) { 41 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 42 return false; 43 44 auto matchOperands = [&](const OperandRange &lhs, 45 const OperandRange &rhs) -> bool { 46 // TODO: Extend this to support aliases and equal constants. 47 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 48 }; 49 return matchOperands(firstPloop.getLowerBound(), 50 secondPloop.getLowerBound()) && 51 matchOperands(firstPloop.getUpperBound(), 52 secondPloop.getUpperBound()) && 53 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 54 } 55 56 /// Checks if the parallel loops have mixed access to the same buffers. Returns 57 /// `true` if the first parallel loop writes to the same indices that the second 58 /// loop reads. 59 static bool haveNoReadsAfterWriteExceptSameIndex( 60 ParallelOp firstPloop, ParallelOp secondPloop, 61 const BlockAndValueMapping &firstToSecondPloopIndices) { 62 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 63 firstPloop.getBody()->walk([&](memref::StoreOp store) { 64 bufferStores[store.getMemRef()].push_back(store.getIndices()); 65 }); 66 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 67 // Stop if the memref is defined in secondPloop body. Careful alias analysis 68 // is needed. 69 auto *memrefDef = load.getMemRef().getDefiningOp(); 70 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 71 return WalkResult::interrupt(); 72 73 auto write = bufferStores.find(load.getMemRef()); 74 if (write == bufferStores.end()) 75 return WalkResult::advance(); 76 77 // Allow only single write access per buffer. 78 if (write->second.size() != 1) 79 return WalkResult::interrupt(); 80 81 // Check that the load indices of secondPloop coincide with store indices of 82 // firstPloop for the same memrefs. 83 auto storeIndices = write->second.front(); 84 auto loadIndices = load.getIndices(); 85 if (storeIndices.size() != loadIndices.size()) 86 return WalkResult::interrupt(); 87 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 88 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 89 loadIndices[i]) 90 return WalkResult::interrupt(); 91 } 92 return WalkResult::advance(); 93 }); 94 return !walkResult.wasInterrupted(); 95 } 96 97 /// Analyzes dependencies in the most primitive way by checking simple read and 98 /// write patterns. 99 static LogicalResult 100 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 101 const BlockAndValueMapping &firstToSecondPloopIndices) { 102 if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, 103 firstToSecondPloopIndices)) 104 return failure(); 105 106 BlockAndValueMapping secondToFirstPloopIndices; 107 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 108 firstPloop.getBody()->getArguments()); 109 return success(haveNoReadsAfterWriteExceptSameIndex( 110 secondPloop, firstPloop, secondToFirstPloopIndices)); 111 } 112 113 static bool 114 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 115 const BlockAndValueMapping &firstToSecondPloopIndices) { 116 return !hasNestedParallelOp(firstPloop) && 117 !hasNestedParallelOp(secondPloop) && 118 equalIterationSpaces(firstPloop, secondPloop) && 119 succeeded(verifyDependencies(firstPloop, secondPloop, 120 firstToSecondPloopIndices)); 121 } 122 123 /// Prepends operations of firstPloop's body into secondPloop's body. 124 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 125 OpBuilder b) { 126 BlockAndValueMapping firstToSecondPloopIndices; 127 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 128 secondPloop.getBody()->getArguments()); 129 130 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) 131 return; 132 133 b.setInsertionPointToStart(secondPloop.getBody()); 134 for (auto &op : firstPloop.getBody()->without_terminator()) 135 b.clone(op, firstToSecondPloopIndices); 136 firstPloop.erase(); 137 } 138 139 void mlir::scf::naivelyFuseParallelOps(Region ®ion) { 140 OpBuilder b(region); 141 // Consider every single block and attempt to fuse adjacent loops. 142 for (auto &block : region) { 143 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 144 // Not using `walk()` to traverse only top-level parallel loops and also 145 // make sure that there are no side-effecting ops between the parallel 146 // loops. 147 bool noSideEffects = true; 148 for (auto &op : block) { 149 if (auto ploop = dyn_cast<ParallelOp>(op)) { 150 if (noSideEffects) { 151 ploopChains.back().push_back(ploop); 152 } else { 153 ploopChains.push_back({ploop}); 154 noSideEffects = true; 155 } 156 continue; 157 } 158 // TODO: Handle region side effects properly. 159 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 160 } 161 for (ArrayRef<ParallelOp> ploops : ploopChains) { 162 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 163 fuseIfLegal(ploops[i], ploops[i + 1], b); 164 } 165 } 166 } 167 168 namespace { 169 struct ParallelLoopFusion 170 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 171 void runOnOperation() override { 172 getOperation()->walk([&](Operation *child) { 173 for (Region ®ion : child->getRegions()) 174 naivelyFuseParallelOps(region); 175 }); 176 } 177 }; 178 } // namespace 179 180 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 181 return std::make_unique<ParallelLoopFusion>(); 182 } 183