1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms/Passes.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/OpDefinition.h" 21 22 using namespace mlir; 23 using namespace mlir::scf; 24 25 /// Verify there are no nested ParallelOps. 26 static bool hasNestedParallelOp(ParallelOp ploop) { 27 auto walkResult = 28 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 29 return walkResult.wasInterrupted(); 30 } 31 32 /// Verify equal iteration spaces. 33 static bool equalIterationSpaces(ParallelOp firstPloop, 34 ParallelOp secondPloop) { 35 if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 36 return false; 37 38 auto matchOperands = [&](const OperandRange &lhs, 39 const OperandRange &rhs) -> bool { 40 // TODO: Extend this to support aliases and equal constants. 41 return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 42 }; 43 return matchOperands(firstPloop.getLowerBound(), 44 secondPloop.getLowerBound()) && 45 matchOperands(firstPloop.getUpperBound(), 46 secondPloop.getUpperBound()) && 47 matchOperands(firstPloop.getStep(), secondPloop.getStep()); 48 } 49 50 /// Checks if the parallel loops have mixed access to the same buffers. Returns 51 /// `true` if the first parallel loop writes to the same indices that the second 52 /// loop reads. 53 static bool haveNoReadsAfterWriteExceptSameIndex( 54 ParallelOp firstPloop, ParallelOp secondPloop, 55 const BlockAndValueMapping &firstToSecondPloopIndices) { 56 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 57 firstPloop.getBody()->walk([&](memref::StoreOp store) { 58 bufferStores[store.getMemRef()].push_back(store.getIndices()); 59 }); 60 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 61 // Stop if the memref is defined in secondPloop body. Careful alias analysis 62 // is needed. 63 auto *memrefDef = load.getMemRef().getDefiningOp(); 64 if (memrefDef && memrefDef->getBlock() == load->getBlock()) 65 return WalkResult::interrupt(); 66 67 auto write = bufferStores.find(load.getMemRef()); 68 if (write == bufferStores.end()) 69 return WalkResult::advance(); 70 71 // Allow only single write access per buffer. 72 if (write->second.size() != 1) 73 return WalkResult::interrupt(); 74 75 // Check that the load indices of secondPloop coincide with store indices of 76 // firstPloop for the same memrefs. 77 auto storeIndices = write->second.front(); 78 auto loadIndices = load.getIndices(); 79 if (storeIndices.size() != loadIndices.size()) 80 return WalkResult::interrupt(); 81 for (int i = 0, e = storeIndices.size(); i < e; ++i) { 82 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 83 loadIndices[i]) 84 return WalkResult::interrupt(); 85 } 86 return WalkResult::advance(); 87 }); 88 return !walkResult.wasInterrupted(); 89 } 90 91 /// Analyzes dependencies in the most primitive way by checking simple read and 92 /// write patterns. 93 static LogicalResult 94 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 95 const BlockAndValueMapping &firstToSecondPloopIndices) { 96 if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, 97 firstToSecondPloopIndices)) 98 return failure(); 99 100 BlockAndValueMapping secondToFirstPloopIndices; 101 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 102 firstPloop.getBody()->getArguments()); 103 return success(haveNoReadsAfterWriteExceptSameIndex( 104 secondPloop, firstPloop, secondToFirstPloopIndices)); 105 } 106 107 static bool 108 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 109 const BlockAndValueMapping &firstToSecondPloopIndices) { 110 return !hasNestedParallelOp(firstPloop) && 111 !hasNestedParallelOp(secondPloop) && 112 equalIterationSpaces(firstPloop, secondPloop) && 113 succeeded(verifyDependencies(firstPloop, secondPloop, 114 firstToSecondPloopIndices)); 115 } 116 117 /// Prepends operations of firstPloop's body into secondPloop's body. 118 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 119 OpBuilder b) { 120 BlockAndValueMapping firstToSecondPloopIndices; 121 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 122 secondPloop.getBody()->getArguments()); 123 124 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) 125 return; 126 127 b.setInsertionPointToStart(secondPloop.getBody()); 128 for (auto &op : firstPloop.getBody()->without_terminator()) 129 b.clone(op, firstToSecondPloopIndices); 130 firstPloop.erase(); 131 } 132 133 void mlir::scf::naivelyFuseParallelOps(Region ®ion) { 134 OpBuilder b(region); 135 // Consider every single block and attempt to fuse adjacent loops. 136 for (auto &block : region) { 137 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 138 // Not using `walk()` to traverse only top-level parallel loops and also 139 // make sure that there are no side-effecting ops between the parallel 140 // loops. 141 bool noSideEffects = true; 142 for (auto &op : block) { 143 if (auto ploop = dyn_cast<ParallelOp>(op)) { 144 if (noSideEffects) { 145 ploopChains.back().push_back(ploop); 146 } else { 147 ploopChains.push_back({ploop}); 148 noSideEffects = true; 149 } 150 continue; 151 } 152 // TODO: Handle region side effects properly. 153 noSideEffects &= 154 MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0; 155 } 156 for (ArrayRef<ParallelOp> ploops : ploopChains) { 157 for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 158 fuseIfLegal(ploops[i], ploops[i + 1], b); 159 } 160 } 161 } 162 163 namespace { 164 struct ParallelLoopFusion 165 : public SCFParallelLoopFusionBase<ParallelLoopFusion> { 166 void runOnOperation() override { 167 getOperation()->walk([&](Operation *child) { 168 for (Region ®ion : child->getRegions()) 169 naivelyFuseParallelOps(region); 170 }); 171 } 172 }; 173 } // namespace 174 175 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 176 return std::make_unique<ParallelLoopFusion>(); 177 } 178