1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Analysis/AliasAnalysis.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/Interfaces/SideEffectInterfaces.h" 23 24 namespace mlir { 25 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 26 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 27 } // namespace mlir 28 29 using namespace mlir; 30 using namespace mlir::scf; 31 32 /// Verify there are no nested ParallelOps. 33 static bool hasNestedParallelOp(ParallelOp ploop) { 34 auto walkResult = 35 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 36 return walkResult.wasInterrupted(); 37 } 38 39 /// Verify equal iteration spaces. 40 static bool equalIterationSpaces(ParallelOp firstPloop, 41 ParallelOp secondPloop) { 42 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 43 return false; 44 45 auto matchOperands = [&](const OperandRange &lhs, 46 const OperandRange &rhs) -> bool { 47 // TODO: Extend this to support aliases and equal constants. 48 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 49 }; 50 return matchOperands(firstPloop.getLowerBound(), 51 secondPloop.getLowerBound()) && 52 matchOperands(firstPloop.getUpperBound(), 53 secondPloop.getUpperBound()) && 54 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 55 } 56 57 /// Checks if the parallel loops have mixed access to the same buffers. Returns 58 /// `true` if the first parallel loop writes to the same indices that the second 59 /// loop reads. 60 static bool haveNoReadsAfterWriteExceptSameIndex( 61 ParallelOp firstPloop, ParallelOp secondPloop, 62 const IRMapping &firstToSecondPloopIndices, 63 llvm::function_ref<bool(Value, Value)> mayAlias) { 64 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 65 SmallVector<Value> bufferStoresVec; 66 firstPloop.getBody()->walk([&](memref::StoreOp store) { 67 bufferStores[store.getMemRef()].push_back(store.getIndices()); 68 bufferStoresVec.emplace_back(store.getMemRef()); 69 }); 70 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 71 Value loadMem = load.getMemRef(); 72 // Stop if the memref is defined in secondPloop body. Careful alias analysis 73 // is needed. 74 auto *memrefDef = loadMem.getDefiningOp(); 75 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 76 return WalkResult::interrupt(); 77 78 for (Value store : bufferStoresVec) 79 if (store != loadMem && mayAlias(store, loadMem)) 80 return WalkResult::interrupt(); 81 82 auto write = bufferStores.find(loadMem); 83 if (write == bufferStores.end()) 84 return WalkResult::advance(); 85 86 // Check that at last one store was retrieved 87 if (!write->second.size()) 88 return WalkResult::interrupt(); 89 90 auto storeIndices = write->second.front(); 91 92 // Multiple writes to the same memref are allowed only on the same indices 93 for (const auto &othStoreIndices : write->second) { 94 if (othStoreIndices != storeIndices) 95 return WalkResult::interrupt(); 96 } 97 98 // Check that the load indices of secondPloop coincide with store indices of 99 // firstPloop for the same memrefs. 100 auto loadIndices = load.getIndices(); 101 if (storeIndices.size() != loadIndices.size()) 102 return WalkResult::interrupt(); 103 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 104 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 105 loadIndices[i]) 106 return WalkResult::interrupt(); 107 } 108 return WalkResult::advance(); 109 }); 110 return !walkResult.wasInterrupted(); 111 } 112 113 /// Analyzes dependencies in the most primitive way by checking simple read and 114 /// write patterns. 115 static LogicalResult 116 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 117 const IRMapping &firstToSecondPloopIndices, 118 llvm::function_ref<bool(Value, Value)> mayAlias) { 119 if (!haveNoReadsAfterWriteExceptSameIndex( 120 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) 121 return failure(); 122 123 IRMapping secondToFirstPloopIndices; 124 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 125 firstPloop.getBody()->getArguments()); 126 return success(haveNoReadsAfterWriteExceptSameIndex( 127 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); 128 } 129 130 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 131 const IRMapping &firstToSecondPloopIndices, 132 llvm::function_ref<bool(Value, Value)> mayAlias) { 133 return !hasNestedParallelOp(firstPloop) && 134 !hasNestedParallelOp(secondPloop) && 135 equalIterationSpaces(firstPloop, secondPloop) && 136 succeeded(verifyDependencies(firstPloop, secondPloop, 137 firstToSecondPloopIndices, mayAlias)); 138 } 139 140 /// Prepends operations of firstPloop's body into secondPloop's body. 141 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 142 OpBuilder b, 143 llvm::function_ref<bool(Value, Value)> mayAlias) { 144 IRMapping firstToSecondPloopIndices; 145 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 146 secondPloop.getBody()->getArguments()); 147 148 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, 149 mayAlias)) 150 return; 151 152 b.setInsertionPointToStart(secondPloop.getBody()); 153 for (auto &op : firstPloop.getBody()->without_terminator()) 154 b.clone(op, firstToSecondPloopIndices); 155 firstPloop.erase(); 156 } 157 158 void mlir::scf::naivelyFuseParallelOps( 159 Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { 160 OpBuilder b(region); 161 // Consider every single block and attempt to fuse adjacent loops. 162 for (auto &block : region) { 163 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 164 // Not using `walk()` to traverse only top-level parallel loops and also 165 // make sure that there are no side-effecting ops between the parallel 166 // loops. 167 bool noSideEffects = true; 168 for (auto &op : block) { 169 if (auto ploop = dyn_cast<ParallelOp>(op)) { 170 if (noSideEffects) { 171 ploopChains.back().push_back(ploop); 172 } else { 173 ploopChains.push_back({ploop}); 174 noSideEffects = true; 175 } 176 continue; 177 } 178 // TODO: Handle region side effects properly. 179 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 180 } 181 for (ArrayRef<ParallelOp> ploops : ploopChains) { 182 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 183 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); 184 } 185 } 186 } 187 188 namespace { 189 struct ParallelLoopFusion 190 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 191 void runOnOperation() override { 192 auto &AA = getAnalysis<AliasAnalysis>(); 193 194 auto mayAlias = [&](Value val1, Value val2) -> bool { 195 return !AA.alias(val1, val2).isNo(); 196 }; 197 198 getOperation()->walk([&](Operation *child) { 199 for (Region ®ion : child->getRegions()) 200 naivelyFuseParallelOps(region, mayAlias); 201 }); 202 } 203 }; 204 } // namespace 205 206 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 207 return std::make_unique<ParallelLoopFusion>(); 208 } 209