1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/Passes.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/OpDefinition.h" 22 23 using namespace mlir; 24 using namespace mlir::scf; 25 26 /// Verify there are no nested ParallelOps. 27 static bool hasNestedParallelOp(ParallelOp ploop) { 28 auto walkResult = 29 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 30 return walkResult.wasInterrupted(); 31 } 32 33 /// Verify equal iteration spaces. 34 static bool equalIterationSpaces(ParallelOp firstPloop, 35 ParallelOp secondPloop) { 36 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 37 return false; 38 39 auto matchOperands = [&](const OperandRange &lhs, 40 const OperandRange &rhs) -> bool { 41 // TODO: Extend this to support aliases and equal constants. 42 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 43 }; 44 return matchOperands(firstPloop.getLowerBound(), 45 secondPloop.getLowerBound()) && 46 matchOperands(firstPloop.getUpperBound(), 47 secondPloop.getUpperBound()) && 48 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 49 } 50 51 /// Checks if the parallel loops have mixed access to the same buffers. Returns 52 /// `true` if the first parallel loop writes to the same indices that the second 53 /// loop reads. 54 static bool haveNoReadsAfterWriteExceptSameIndex( 55 ParallelOp firstPloop, ParallelOp secondPloop, 56 const BlockAndValueMapping &firstToSecondPloopIndices) { 57 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 58 firstPloop.getBody()->walk([&](memref::StoreOp store) { 59 bufferStores[store.getMemRef()].push_back(store.indices()); 60 }); 61 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 62 // Stop if the memref is defined in secondPloop body. Careful alias analysis 63 // is needed. 64 auto *memrefDef = load.getMemRef().getDefiningOp(); 65 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 66 return WalkResult::interrupt(); 67 68 auto write = bufferStores.find(load.getMemRef()); 69 if (write == bufferStores.end()) 70 return WalkResult::advance(); 71 72 // Allow only single write access per buffer. 73 if (write->second.size() != 1) 74 return WalkResult::interrupt(); 75 76 // Check that the load indices of secondPloop coincide with store indices of 77 // firstPloop for the same memrefs. 78 auto storeIndices = write->second.front(); 79 auto loadIndices = load.indices(); 80 if (storeIndices.size() != loadIndices.size()) 81 return WalkResult::interrupt(); 82 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 83 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 84 loadIndices[i]) 85 return WalkResult::interrupt(); 86 } 87 return WalkResult::advance(); 88 }); 89 return !walkResult.wasInterrupted(); 90 } 91 92 /// Analyzes dependencies in the most primitive way by checking simple read and 93 /// write patterns. 94 static LogicalResult 95 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 96 const BlockAndValueMapping &firstToSecondPloopIndices) { 97 if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, 98 firstToSecondPloopIndices)) 99 return failure(); 100 101 BlockAndValueMapping secondToFirstPloopIndices; 102 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 103 firstPloop.getBody()->getArguments()); 104 return success(haveNoReadsAfterWriteExceptSameIndex( 105 secondPloop, firstPloop, secondToFirstPloopIndices)); 106 } 107 108 static bool 109 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 110 const BlockAndValueMapping &firstToSecondPloopIndices) { 111 return !hasNestedParallelOp(firstPloop) && 112 !hasNestedParallelOp(secondPloop) && 113 equalIterationSpaces(firstPloop, secondPloop) && 114 succeeded(verifyDependencies(firstPloop, secondPloop, 115 firstToSecondPloopIndices)); 116 } 117 118 /// Prepends operations of firstPloop's body into secondPloop's body. 119 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 120 OpBuilder b) { 121 BlockAndValueMapping firstToSecondPloopIndices; 122 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 123 secondPloop.getBody()->getArguments()); 124 125 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) 126 return; 127 128 b.setInsertionPointToStart(secondPloop.getBody()); 129 for (auto &op : firstPloop.getBody()->without_terminator()) 130 b.clone(op, firstToSecondPloopIndices); 131 firstPloop.erase(); 132 } 133 134 void mlir::scf::naivelyFuseParallelOps(Region ®ion) { 135 OpBuilder b(region); 136 // Consider every single block and attempt to fuse adjacent loops. 137 for (auto &block : region) { 138 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 139 // Not using `walk()` to traverse only top-level parallel loops and also 140 // make sure that there are no side-effecting ops between the parallel 141 // loops. 142 bool noSideEffects = true; 143 for (auto &op : block) { 144 if (auto ploop = dyn_cast<ParallelOp>(op)) { 145 if (noSideEffects) { 146 ploopChains.back().push_back(ploop); 147 } else { 148 ploopChains.push_back({ploop}); 149 noSideEffects = true; 150 } 151 continue; 152 } 153 // TODO: Handle region side effects properly. 154 noSideEffects &= 155 MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0; 156 } 157 for (ArrayRef<ParallelOp> ploops : ploopChains) { 158 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 159 fuseIfLegal(ploops[i], ploops[i + 1], b); 160 } 161 } 162 } 163 164 namespace { 165 struct ParallelLoopFusion 166 : public SCFParallelLoopFusionBase<ParallelLoopFusion> { 167 void runOnOperation() override { 168 getOperation()->walk([&](Operation *child) { 169 for (Region ®ion : child->getRegions()) 170 naivelyFuseParallelOps(region); 171 }); 172 } 173 }; 174 } // namespace 175 176 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 177 return std::make_unique<ParallelLoopFusion>(); 178 } 179