1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Analysis/AliasAnalysis.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/IR/OperationSupport.h" 23 #include "mlir/Interfaces/SideEffectInterfaces.h" 24 25 namespace mlir { 26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 28 } // namespace mlir 29 30 using namespace mlir; 31 using namespace mlir::scf; 32 33 /// Verify there are no nested ParallelOps. 34 static bool hasNestedParallelOp(ParallelOp ploop) { 35 auto walkResult = 36 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 37 return walkResult.wasInterrupted(); 38 } 39 40 /// Verify equal iteration spaces. 41 static bool equalIterationSpaces(ParallelOp firstPloop, 42 ParallelOp secondPloop) { 43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 44 return false; 45 46 auto matchOperands = [&](const OperandRange &lhs, 47 const OperandRange &rhs) -> bool { 48 // TODO: Extend this to support aliases and equal constants. 49 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 50 }; 51 return matchOperands(firstPloop.getLowerBound(), 52 secondPloop.getLowerBound()) && 53 matchOperands(firstPloop.getUpperBound(), 54 secondPloop.getUpperBound()) && 55 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 56 } 57 58 /// Checks if the parallel loops have mixed access to the same buffers. Returns 59 /// `true` if the first parallel loop writes to the same indices that the second 60 /// loop reads. 61 static bool haveNoReadsAfterWriteExceptSameIndex( 62 ParallelOp firstPloop, ParallelOp secondPloop, 63 const IRMapping &firstToSecondPloopIndices, 64 llvm::function_ref<bool(Value, Value)> mayAlias) { 65 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 66 SmallVector<Value> bufferStoresVec; 67 firstPloop.getBody()->walk([&](memref::StoreOp store) { 68 bufferStores[store.getMemRef()].push_back(store.getIndices()); 69 bufferStoresVec.emplace_back(store.getMemRef()); 70 }); 71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 72 Value loadMem = load.getMemRef(); 73 // Stop if the memref is defined in secondPloop body. Careful alias analysis 74 // is needed. 75 auto *memrefDef = loadMem.getDefiningOp(); 76 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 77 return WalkResult::interrupt(); 78 79 for (Value store : bufferStoresVec) 80 if (store != loadMem && mayAlias(store, loadMem)) 81 return WalkResult::interrupt(); 82 83 auto write = bufferStores.find(loadMem); 84 if (write == bufferStores.end()) 85 return WalkResult::advance(); 86 87 // Check that at last one store was retrieved 88 if (!write->second.size()) 89 return WalkResult::interrupt(); 90 91 auto storeIndices = write->second.front(); 92 93 // Multiple writes to the same memref are allowed only on the same indices 94 for (const auto &othStoreIndices : write->second) { 95 if (othStoreIndices != storeIndices) 96 return WalkResult::interrupt(); 97 } 98 99 // Check that the load indices of secondPloop coincide with store indices of 100 // firstPloop for the same memrefs. 101 auto loadIndices = load.getIndices(); 102 if (storeIndices.size() != loadIndices.size()) 103 return WalkResult::interrupt(); 104 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 106 loadIndices[i]) { 107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); 108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); 109 if (storeIndexDefOp && loadIndexDefOp) { 110 if (!isMemoryEffectFree(storeIndexDefOp)) 111 return WalkResult::interrupt(); 112 if (!isMemoryEffectFree(loadIndexDefOp)) 113 return WalkResult::interrupt(); 114 if (!OperationEquivalence::isEquivalentTo( 115 storeIndexDefOp, loadIndexDefOp, 116 [&](Value storeIndex, Value loadIndex) { 117 if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != 118 firstToSecondPloopIndices.lookupOrDefault(loadIndex)) 119 return failure(); 120 else 121 return success(); 122 }, 123 /*markEquivalent=*/nullptr, 124 OperationEquivalence::Flags::IgnoreLocations)) { 125 return WalkResult::interrupt(); 126 } 127 } else 128 return WalkResult::interrupt(); 129 } 130 } 131 return WalkResult::advance(); 132 }); 133 return !walkResult.wasInterrupted(); 134 } 135 136 /// Analyzes dependencies in the most primitive way by checking simple read and 137 /// write patterns. 138 static LogicalResult 139 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 140 const IRMapping &firstToSecondPloopIndices, 141 llvm::function_ref<bool(Value, Value)> mayAlias) { 142 if (!haveNoReadsAfterWriteExceptSameIndex( 143 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) 144 return failure(); 145 146 IRMapping secondToFirstPloopIndices; 147 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 148 firstPloop.getBody()->getArguments()); 149 return success(haveNoReadsAfterWriteExceptSameIndex( 150 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); 151 } 152 153 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 154 const IRMapping &firstToSecondPloopIndices, 155 llvm::function_ref<bool(Value, Value)> mayAlias) { 156 return !hasNestedParallelOp(firstPloop) && 157 !hasNestedParallelOp(secondPloop) && 158 equalIterationSpaces(firstPloop, secondPloop) && 159 succeeded(verifyDependencies(firstPloop, secondPloop, 160 firstToSecondPloopIndices, mayAlias)); 161 } 162 163 /// Prepends operations of firstPloop's body into secondPloop's body. 164 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 165 OpBuilder b, 166 llvm::function_ref<bool(Value, Value)> mayAlias) { 167 IRMapping firstToSecondPloopIndices; 168 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 169 secondPloop.getBody()->getArguments()); 170 171 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, 172 mayAlias)) 173 return; 174 175 b.setInsertionPointToStart(secondPloop.getBody()); 176 for (auto &op : firstPloop.getBody()->without_terminator()) 177 b.clone(op, firstToSecondPloopIndices); 178 firstPloop.erase(); 179 } 180 181 void mlir::scf::naivelyFuseParallelOps( 182 Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { 183 OpBuilder b(region); 184 // Consider every single block and attempt to fuse adjacent loops. 185 for (auto &block : region) { 186 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 187 // Not using `walk()` to traverse only top-level parallel loops and also 188 // make sure that there are no side-effecting ops between the parallel 189 // loops. 190 bool noSideEffects = true; 191 for (auto &op : block) { 192 if (auto ploop = dyn_cast<ParallelOp>(op)) { 193 if (noSideEffects) { 194 ploopChains.back().push_back(ploop); 195 } else { 196 ploopChains.push_back({ploop}); 197 noSideEffects = true; 198 } 199 continue; 200 } 201 // TODO: Handle region side effects properly. 202 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 203 } 204 for (ArrayRef<ParallelOp> ploops : ploopChains) { 205 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 206 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); 207 } 208 } 209 } 210 211 namespace { 212 struct ParallelLoopFusion 213 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 214 void runOnOperation() override { 215 auto &AA = getAnalysis<AliasAnalysis>(); 216 217 auto mayAlias = [&](Value val1, Value val2) -> bool { 218 return !AA.alias(val1, val2).isNo(); 219 }; 220 221 getOperation()->walk([&](Operation *child) { 222 for (Region ®ion : child->getRegions()) 223 naivelyFuseParallelOps(region, mayAlias); 224 }); 225 } 226 }; 227 } // namespace 228 229 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 230 return std::make_unique<ParallelLoopFusion>(); 231 } 232