1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Analysis/AliasAnalysis.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/IRMapping.h" 22 #include "mlir/IR/OpDefinition.h" 23 #include "mlir/IR/OperationSupport.h" 24 #include "mlir/Interfaces/SideEffectInterfaces.h" 25 26 namespace mlir { 27 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 28 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace mlir::scf; 33 34 /// Verify there are no nested ParallelOps. 35 static bool hasNestedParallelOp(ParallelOp ploop) { 36 auto walkResult = 37 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 38 return walkResult.wasInterrupted(); 39 } 40 41 /// Checks if the parallel loops have mixed access to the same buffers. Returns 42 /// `true` if the first parallel loop writes to the same indices that the second 43 /// loop reads. 44 static bool haveNoReadsAfterWriteExceptSameIndex( 45 ParallelOp firstPloop, ParallelOp secondPloop, 46 const IRMapping &firstToSecondPloopIndices, 47 llvm::function_ref<bool(Value, Value)> mayAlias) { 48 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 49 SmallVector<Value> bufferStoresVec; 50 firstPloop.getBody()->walk([&](memref::StoreOp store) { 51 bufferStores[store.getMemRef()].push_back(store.getIndices()); 52 bufferStoresVec.emplace_back(store.getMemRef()); 53 }); 54 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 55 Value loadMem = load.getMemRef(); 56 // Stop if the memref is defined in secondPloop body. Careful alias analysis 57 // is needed. 58 auto *memrefDef = loadMem.getDefiningOp(); 59 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 60 return WalkResult::interrupt(); 61 62 for (Value store : bufferStoresVec) 63 if (store != loadMem && mayAlias(store, loadMem)) 64 return WalkResult::interrupt(); 65 66 auto write = bufferStores.find(loadMem); 67 if (write == bufferStores.end()) 68 return WalkResult::advance(); 69 70 // Check that at last one store was retrieved 71 if (!write->second.size()) 72 return WalkResult::interrupt(); 73 74 auto storeIndices = write->second.front(); 75 76 // Multiple writes to the same memref are allowed only on the same indices 77 for (const auto &othStoreIndices : write->second) { 78 if (othStoreIndices != storeIndices) 79 return WalkResult::interrupt(); 80 } 81 82 // Check that the load indices of secondPloop coincide with store indices of 83 // firstPloop for the same memrefs. 84 auto loadIndices = load.getIndices(); 85 if (storeIndices.size() != loadIndices.size()) 86 return WalkResult::interrupt(); 87 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 88 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 89 loadIndices[i]) { 90 auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); 91 auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); 92 if (storeIndexDefOp && loadIndexDefOp) { 93 if (!isMemoryEffectFree(storeIndexDefOp)) 94 return WalkResult::interrupt(); 95 if (!isMemoryEffectFree(loadIndexDefOp)) 96 return WalkResult::interrupt(); 97 if (!OperationEquivalence::isEquivalentTo( 98 storeIndexDefOp, loadIndexDefOp, 99 [&](Value storeIndex, Value loadIndex) { 100 if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != 101 firstToSecondPloopIndices.lookupOrDefault(loadIndex)) 102 return failure(); 103 else 104 return success(); 105 }, 106 /*markEquivalent=*/nullptr, 107 OperationEquivalence::Flags::IgnoreLocations)) { 108 return WalkResult::interrupt(); 109 } 110 } else 111 return WalkResult::interrupt(); 112 } 113 } 114 return WalkResult::advance(); 115 }); 116 return !walkResult.wasInterrupted(); 117 } 118 119 /// Analyzes dependencies in the most primitive way by checking simple read and 120 /// write patterns. 121 static LogicalResult 122 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 123 const IRMapping &firstToSecondPloopIndices, 124 llvm::function_ref<bool(Value, Value)> mayAlias) { 125 if (!haveNoReadsAfterWriteExceptSameIndex( 126 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) 127 return failure(); 128 129 IRMapping secondToFirstPloopIndices; 130 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 131 firstPloop.getBody()->getArguments()); 132 return success(haveNoReadsAfterWriteExceptSameIndex( 133 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); 134 } 135 136 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 137 const IRMapping &firstToSecondPloopIndices, 138 llvm::function_ref<bool(Value, Value)> mayAlias) { 139 Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); 140 return !hasNestedParallelOp(firstPloop) && 141 !hasNestedParallelOp(secondPloop) && 142 checkFusionStructuralLegality(firstPloop, secondPloop, diag) && 143 succeeded(verifyDependencies(firstPloop, secondPloop, 144 firstToSecondPloopIndices, mayAlias)); 145 } 146 147 /// Prepends operations of firstPloop's body into secondPloop's body. 148 /// Updates secondPloop with new loop. 149 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, 150 OpBuilder builder, 151 llvm::function_ref<bool(Value, Value)> mayAlias) { 152 Block *block1 = firstPloop.getBody(); 153 Block *block2 = secondPloop.getBody(); 154 IRMapping firstToSecondPloopIndices; 155 firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); 156 157 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, 158 mayAlias)) 159 return; 160 161 IRRewriter rewriter(builder); 162 secondPloop = mlir::fuseIndependentSiblingParallelLoops( 163 firstPloop, secondPloop, rewriter); 164 } 165 166 void mlir::scf::naivelyFuseParallelOps( 167 Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { 168 OpBuilder b(region); 169 // Consider every single block and attempt to fuse adjacent loops. 170 SmallVector<SmallVector<ParallelOp>, 1> ploopChains; 171 for (auto &block : region) { 172 ploopChains.clear(); 173 ploopChains.push_back({}); 174 175 // Not using `walk()` to traverse only top-level parallel loops and also 176 // make sure that there are no side-effecting ops between the parallel 177 // loops. 178 bool noSideEffects = true; 179 for (auto &op : block) { 180 if (auto ploop = dyn_cast<ParallelOp>(op)) { 181 if (noSideEffects) { 182 ploopChains.back().push_back(ploop); 183 } else { 184 ploopChains.push_back({ploop}); 185 noSideEffects = true; 186 } 187 continue; 188 } 189 // TODO: Handle region side effects properly. 190 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 191 } 192 for (MutableArrayRef<ParallelOp> ploops : ploopChains) { 193 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 194 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); 195 } 196 } 197 } 198 199 namespace { 200 struct ParallelLoopFusion 201 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 202 void runOnOperation() override { 203 auto &AA = getAnalysis<AliasAnalysis>(); 204 205 auto mayAlias = [&](Value val1, Value val2) -> bool { 206 return !AA.alias(val1, val2).isNo(); 207 }; 208 209 getOperation()->walk([&](Operation *child) { 210 for (Region ®ion : child->getRegions()) 211 naivelyFuseParallelOps(region, mayAlias); 212 }); 213 } 214 }; 215 } // namespace 216 217 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 218 return std::make_unique<ParallelLoopFusion>(); 219 } 220