1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Analysis/AliasAnalysis.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/IR/OperationSupport.h" 23 #include "mlir/Interfaces/SideEffectInterfaces.h" 24 25 namespace mlir { 26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 28 } // namespace mlir 29 30 using namespace mlir; 31 using namespace mlir::scf; 32 33 /// Verify there are no nested ParallelOps. 34 static bool hasNestedParallelOp(ParallelOp ploop) { 35 auto walkResult = 36 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 37 return walkResult.wasInterrupted(); 38 } 39 40 /// Verify equal iteration spaces. 41 static bool equalIterationSpaces(ParallelOp firstPloop, 42 ParallelOp secondPloop) { 43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 44 return false; 45 46 auto matchOperands = [&](const OperandRange &lhs, 47 const OperandRange &rhs) -> bool { 48 // TODO: Extend this to support aliases and equal constants. 49 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 50 }; 51 return matchOperands(firstPloop.getLowerBound(), 52 secondPloop.getLowerBound()) && 53 matchOperands(firstPloop.getUpperBound(), 54 secondPloop.getUpperBound()) && 55 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 56 } 57 58 /// Checks if the parallel loops have mixed access to the same buffers. Returns 59 /// `true` if the first parallel loop writes to the same indices that the second 60 /// loop reads. 61 static bool haveNoReadsAfterWriteExceptSameIndex( 62 ParallelOp firstPloop, ParallelOp secondPloop, 63 const IRMapping &firstToSecondPloopIndices, 64 llvm::function_ref<bool(Value, Value)> mayAlias) { 65 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 66 SmallVector<Value> bufferStoresVec; 67 firstPloop.getBody()->walk([&](memref::StoreOp store) { 68 bufferStores[store.getMemRef()].push_back(store.getIndices()); 69 bufferStoresVec.emplace_back(store.getMemRef()); 70 }); 71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 72 Value loadMem = load.getMemRef(); 73 // Stop if the memref is defined in secondPloop body. Careful alias analysis 74 // is needed. 75 auto *memrefDef = loadMem.getDefiningOp(); 76 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 77 return WalkResult::interrupt(); 78 79 for (Value store : bufferStoresVec) 80 if (store != loadMem && mayAlias(store, loadMem)) 81 return WalkResult::interrupt(); 82 83 auto write = bufferStores.find(loadMem); 84 if (write == bufferStores.end()) 85 return WalkResult::advance(); 86 87 // Check that at last one store was retrieved 88 if (write->second.empty()) 89 return WalkResult::interrupt(); 90 91 auto storeIndices = write->second.front(); 92 93 // Multiple writes to the same memref are allowed only on the same indices 94 for (const auto &othStoreIndices : write->second) { 95 if (othStoreIndices != storeIndices) 96 return WalkResult::interrupt(); 97 } 98 99 // Check that the load indices of secondPloop coincide with store indices of 100 // firstPloop for the same memrefs. 101 auto loadIndices = load.getIndices(); 102 if (storeIndices.size() != loadIndices.size()) 103 return WalkResult::interrupt(); 104 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 106 loadIndices[i]) { 107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); 108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); 109 if (storeIndexDefOp && loadIndexDefOp) { 110 if (!isMemoryEffectFree(storeIndexDefOp)) 111 return WalkResult::interrupt(); 112 if (!isMemoryEffectFree(loadIndexDefOp)) 113 return WalkResult::interrupt(); 114 if (!OperationEquivalence::isEquivalentTo( 115 storeIndexDefOp, loadIndexDefOp, 116 [&](Value storeIndex, Value loadIndex) { 117 if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != 118 firstToSecondPloopIndices.lookupOrDefault(loadIndex)) 119 return failure(); 120 else 121 return success(); 122 }, 123 /*markEquivalent=*/nullptr, 124 OperationEquivalence::Flags::IgnoreLocations)) { 125 return WalkResult::interrupt(); 126 } 127 } else 128 return WalkResult::interrupt(); 129 } 130 } 131 return WalkResult::advance(); 132 }); 133 return !walkResult.wasInterrupted(); 134 } 135 136 /// Analyzes dependencies in the most primitive way by checking simple read and 137 /// write patterns. 138 static LogicalResult 139 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 140 const IRMapping &firstToSecondPloopIndices, 141 llvm::function_ref<bool(Value, Value)> mayAlias) { 142 if (!haveNoReadsAfterWriteExceptSameIndex( 143 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) 144 return failure(); 145 146 IRMapping secondToFirstPloopIndices; 147 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 148 firstPloop.getBody()->getArguments()); 149 return success(haveNoReadsAfterWriteExceptSameIndex( 150 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); 151 } 152 153 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 154 const IRMapping &firstToSecondPloopIndices, 155 llvm::function_ref<bool(Value, Value)> mayAlias) { 156 return !hasNestedParallelOp(firstPloop) && 157 !hasNestedParallelOp(secondPloop) && 158 equalIterationSpaces(firstPloop, secondPloop) && 159 succeeded(verifyDependencies(firstPloop, secondPloop, 160 firstToSecondPloopIndices, mayAlias)); 161 } 162 163 /// Prepends operations of firstPloop's body into secondPloop's body. 164 /// Updates secondPloop with new loop. 165 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, 166 OpBuilder builder, 167 llvm::function_ref<bool(Value, Value)> mayAlias) { 168 Block *block1 = firstPloop.getBody(); 169 Block *block2 = secondPloop.getBody(); 170 IRMapping firstToSecondPloopIndices; 171 firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); 172 173 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, 174 mayAlias)) 175 return; 176 177 DominanceInfo dom; 178 // We are fusing first loop into second, make sure there are no users of the 179 // first loop results between loops. 180 for (Operation *user : firstPloop->getUsers()) 181 if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) 182 return; 183 184 ValueRange inits1 = firstPloop.getInitVals(); 185 ValueRange inits2 = secondPloop.getInitVals(); 186 187 SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); 188 newInitVars.append(inits2.begin(), inits2.end()); 189 190 IRRewriter b(builder); 191 b.setInsertionPoint(secondPloop); 192 auto newSecondPloop = b.create<ParallelOp>( 193 secondPloop.getLoc(), secondPloop.getLowerBound(), 194 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); 195 196 Block *newBlock = newSecondPloop.getBody(); 197 auto term1 = cast<ReduceOp>(block1->getTerminator()); 198 auto term2 = cast<ReduceOp>(block2->getTerminator()); 199 200 b.inlineBlockBefore(block2, newBlock, newBlock->begin(), 201 newBlock->getArguments()); 202 b.inlineBlockBefore(block1, newBlock, newBlock->begin(), 203 newBlock->getArguments()); 204 205 ValueRange results = newSecondPloop.getResults(); 206 if (!results.empty()) { 207 b.setInsertionPointToEnd(newBlock); 208 209 ValueRange reduceArgs1 = term1.getOperands(); 210 ValueRange reduceArgs2 = term2.getOperands(); 211 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); 212 newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); 213 214 auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); 215 216 for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( 217 term1.getReductions(), term2.getReductions()))) { 218 Block &oldRedBlock = reg.front(); 219 Block &newRedBlock = newReduceOp.getReductions()[i].front(); 220 b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), 221 newRedBlock.getArguments()); 222 } 223 224 firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); 225 secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); 226 } 227 term1->erase(); 228 term2->erase(); 229 firstPloop.erase(); 230 secondPloop.erase(); 231 secondPloop = newSecondPloop; 232 } 233 234 void mlir::scf::naivelyFuseParallelOps( 235 Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { 236 OpBuilder b(region); 237 // Consider every single block and attempt to fuse adjacent loops. 238 SmallVector<SmallVector<ParallelOp>, 1> ploopChains; 239 for (auto &block : region) { 240 ploopChains.clear(); 241 ploopChains.push_back({}); 242 243 // Not using `walk()` to traverse only top-level parallel loops and also 244 // make sure that there are no side-effecting ops between the parallel 245 // loops. 246 bool noSideEffects = true; 247 for (auto &op : block) { 248 if (auto ploop = dyn_cast<ParallelOp>(op)) { 249 if (noSideEffects) { 250 ploopChains.back().push_back(ploop); 251 } else { 252 ploopChains.push_back({ploop}); 253 noSideEffects = true; 254 } 255 continue; 256 } 257 // TODO: Handle region side effects properly. 258 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 259 } 260 for (MutableArrayRef<ParallelOp> ploops : ploopChains) { 261 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 262 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); 263 } 264 } 265 } 266 267 namespace { 268 struct ParallelLoopFusion 269 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 270 void runOnOperation() override { 271 auto &AA = getAnalysis<AliasAnalysis>(); 272 273 auto mayAlias = [&](Value val1, Value val2) -> bool { 274 return !AA.alias(val1, val2).isNo(); 275 }; 276 277 getOperation()->walk([&](Operation *child) { 278 for (Region ®ion : child->getRegions()) 279 naivelyFuseParallelOps(region, mayAlias); 280 }); 281 } 282 }; 283 } // namespace 284 285 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 286 return std::make_unique<ParallelLoopFusion>(); 287 } 288