1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2c25b20c0SAlex Zinenko // 3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c25b20c0SAlex Zinenko // 7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===// 8c25b20c0SAlex Zinenko // 9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops. 10c25b20c0SAlex Zinenko // 11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===// 12c25b20c0SAlex Zinenko 1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 1467d0d7acSMichele Scuttari 15c0d2ea9dSIvan Butygin #include "mlir/Analysis/AliasAnalysis.h" 16e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h" 204d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 21c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h" 22c3eb2978SHsiangkai Wang #include "mlir/IR/OperationSupport.h" 23fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 24c25b20c0SAlex Zinenko 2567d0d7acSMichele Scuttari namespace mlir { 2667d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION 2767d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 2867d0d7acSMichele Scuttari } // namespace mlir 2967d0d7acSMichele Scuttari 30c25b20c0SAlex Zinenko using namespace mlir; 31c25b20c0SAlex Zinenko using namespace mlir::scf; 32c25b20c0SAlex Zinenko 33c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps. 34c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) { 35c25b20c0SAlex Zinenko auto walkResult = 36c25b20c0SAlex Zinenko ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 37c25b20c0SAlex Zinenko return walkResult.wasInterrupted(); 38c25b20c0SAlex Zinenko } 39c25b20c0SAlex Zinenko 4097a2bd84SAlexander Belyaev /// Verify equal iteration spaces. 4197a2bd84SAlexander Belyaev static bool equalIterationSpaces(ParallelOp firstPloop, 4297a2bd84SAlexander Belyaev ParallelOp secondPloop) { 4397a2bd84SAlexander Belyaev if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 4497a2bd84SAlexander Belyaev return false; 4597a2bd84SAlexander Belyaev 4697a2bd84SAlexander Belyaev auto matchOperands = [&](const OperandRange &lhs, 4797a2bd84SAlexander Belyaev const OperandRange &rhs) -> bool { 4897a2bd84SAlexander Belyaev // TODO: Extend this to support aliases and equal constants. 4997a2bd84SAlexander Belyaev return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 5097a2bd84SAlexander Belyaev }; 5197a2bd84SAlexander Belyaev return matchOperands(firstPloop.getLowerBound(), 5297a2bd84SAlexander Belyaev secondPloop.getLowerBound()) && 5397a2bd84SAlexander Belyaev matchOperands(firstPloop.getUpperBound(), 5497a2bd84SAlexander Belyaev secondPloop.getUpperBound()) && 5597a2bd84SAlexander Belyaev matchOperands(firstPloop.getStep(), secondPloop.getStep()); 5697a2bd84SAlexander Belyaev } 5797a2bd84SAlexander Belyaev 58c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns 59c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second 60c25b20c0SAlex Zinenko /// loop reads. 61c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex( 62c25b20c0SAlex Zinenko ParallelOp firstPloop, ParallelOp secondPloop, 63c0d2ea9dSIvan Butygin const IRMapping &firstToSecondPloopIndices, 64c0d2ea9dSIvan Butygin llvm::function_ref<bool(Value, Value)> mayAlias) { 65c25b20c0SAlex Zinenko DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 66c0d2ea9dSIvan Butygin SmallVector<Value> bufferStoresVec; 67e2310704SJulian Gross firstPloop.getBody()->walk([&](memref::StoreOp store) { 68136d746eSJacques Pienaar bufferStores[store.getMemRef()].push_back(store.getIndices()); 69c0d2ea9dSIvan Butygin bufferStoresVec.emplace_back(store.getMemRef()); 70c25b20c0SAlex Zinenko }); 71e2310704SJulian Gross auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 72c0d2ea9dSIvan Butygin Value loadMem = load.getMemRef(); 73c25b20c0SAlex Zinenko // Stop if the memref is defined in secondPloop body. Careful alias analysis 74c25b20c0SAlex Zinenko // is needed. 75c0d2ea9dSIvan Butygin auto *memrefDef = loadMem.getDefiningOp(); 76c4a04059SChristian Sigg if (memrefDef && memrefDef->getBlock() == load->getBlock()) 77c25b20c0SAlex Zinenko return WalkResult::interrupt(); 78c25b20c0SAlex Zinenko 79c0d2ea9dSIvan Butygin for (Value store : bufferStoresVec) 80c0d2ea9dSIvan Butygin if (store != loadMem && mayAlias(store, loadMem)) 81c0d2ea9dSIvan Butygin return WalkResult::interrupt(); 82c0d2ea9dSIvan Butygin 83c0d2ea9dSIvan Butygin auto write = bufferStores.find(loadMem); 84c25b20c0SAlex Zinenko if (write == bufferStores.end()) 85c25b20c0SAlex Zinenko return WalkResult::advance(); 86c25b20c0SAlex Zinenko 87d17b005eSfabrizio-indirli // Check that at last one store was retrieved 88*b3a2208cSAdrian Kuegel if (write->second.empty()) 89c25b20c0SAlex Zinenko return WalkResult::interrupt(); 90c25b20c0SAlex Zinenko 91d17b005eSfabrizio-indirli auto storeIndices = write->second.front(); 92d17b005eSfabrizio-indirli 93d17b005eSfabrizio-indirli // Multiple writes to the same memref are allowed only on the same indices 94d17b005eSfabrizio-indirli for (const auto &othStoreIndices : write->second) { 95d17b005eSfabrizio-indirli if (othStoreIndices != storeIndices) 96d17b005eSfabrizio-indirli return WalkResult::interrupt(); 97d17b005eSfabrizio-indirli } 98d17b005eSfabrizio-indirli 99c25b20c0SAlex Zinenko // Check that the load indices of secondPloop coincide with store indices of 100c25b20c0SAlex Zinenko // firstPloop for the same memrefs. 101136d746eSJacques Pienaar auto loadIndices = load.getIndices(); 102c25b20c0SAlex Zinenko if (storeIndices.size() != loadIndices.size()) 103c25b20c0SAlex Zinenko return WalkResult::interrupt(); 104c25b20c0SAlex Zinenko for (int i = 0, e = storeIndices.size(); i < e; ++i) { 105c25b20c0SAlex Zinenko if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 106c3eb2978SHsiangkai Wang loadIndices[i]) { 107c3eb2978SHsiangkai Wang auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); 108c3eb2978SHsiangkai Wang auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); 109c3eb2978SHsiangkai Wang if (storeIndexDefOp && loadIndexDefOp) { 110c3eb2978SHsiangkai Wang if (!isMemoryEffectFree(storeIndexDefOp)) 111c25b20c0SAlex Zinenko return WalkResult::interrupt(); 112c3eb2978SHsiangkai Wang if (!isMemoryEffectFree(loadIndexDefOp)) 113c3eb2978SHsiangkai Wang return WalkResult::interrupt(); 114c3eb2978SHsiangkai Wang if (!OperationEquivalence::isEquivalentTo( 115c3eb2978SHsiangkai Wang storeIndexDefOp, loadIndexDefOp, 116c3eb2978SHsiangkai Wang [&](Value storeIndex, Value loadIndex) { 117c3eb2978SHsiangkai Wang if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != 118c3eb2978SHsiangkai Wang firstToSecondPloopIndices.lookupOrDefault(loadIndex)) 119c3eb2978SHsiangkai Wang return failure(); 120c3eb2978SHsiangkai Wang else 121c3eb2978SHsiangkai Wang return success(); 122c3eb2978SHsiangkai Wang }, 123c3eb2978SHsiangkai Wang /*markEquivalent=*/nullptr, 124c3eb2978SHsiangkai Wang OperationEquivalence::Flags::IgnoreLocations)) { 125c3eb2978SHsiangkai Wang return WalkResult::interrupt(); 126c3eb2978SHsiangkai Wang } 127c3eb2978SHsiangkai Wang } else 128c3eb2978SHsiangkai Wang return WalkResult::interrupt(); 129c3eb2978SHsiangkai Wang } 130c25b20c0SAlex Zinenko } 131c25b20c0SAlex Zinenko return WalkResult::advance(); 132c25b20c0SAlex Zinenko }); 133c25b20c0SAlex Zinenko return !walkResult.wasInterrupted(); 134c25b20c0SAlex Zinenko } 135c25b20c0SAlex Zinenko 136c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and 137c25b20c0SAlex Zinenko /// write patterns. 138c25b20c0SAlex Zinenko static LogicalResult 139c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 140c0d2ea9dSIvan Butygin const IRMapping &firstToSecondPloopIndices, 141c0d2ea9dSIvan Butygin llvm::function_ref<bool(Value, Value)> mayAlias) { 142c0d2ea9dSIvan Butygin if (!haveNoReadsAfterWriteExceptSameIndex( 143c0d2ea9dSIvan Butygin firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) 144c25b20c0SAlex Zinenko return failure(); 145c25b20c0SAlex Zinenko 1464d67b278SJeff Niu IRMapping secondToFirstPloopIndices; 147c25b20c0SAlex Zinenko secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 148c25b20c0SAlex Zinenko firstPloop.getBody()->getArguments()); 149c25b20c0SAlex Zinenko return success(haveNoReadsAfterWriteExceptSameIndex( 150c0d2ea9dSIvan Butygin secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); 151c25b20c0SAlex Zinenko } 152c25b20c0SAlex Zinenko 1534d67b278SJeff Niu static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 154c0d2ea9dSIvan Butygin const IRMapping &firstToSecondPloopIndices, 155c0d2ea9dSIvan Butygin llvm::function_ref<bool(Value, Value)> mayAlias) { 156c25b20c0SAlex Zinenko return !hasNestedParallelOp(firstPloop) && 157c25b20c0SAlex Zinenko !hasNestedParallelOp(secondPloop) && 15897a2bd84SAlexander Belyaev equalIterationSpaces(firstPloop, secondPloop) && 159c25b20c0SAlex Zinenko succeeded(verifyDependencies(firstPloop, secondPloop, 160c0d2ea9dSIvan Butygin firstToSecondPloopIndices, mayAlias)); 161c25b20c0SAlex Zinenko } 162c25b20c0SAlex Zinenko 163c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body. 1646050cf28SIvan Butygin /// Updates secondPloop with new loop. 1656050cf28SIvan Butygin static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, 1666050cf28SIvan Butygin OpBuilder builder, 167c0d2ea9dSIvan Butygin llvm::function_ref<bool(Value, Value)> mayAlias) { 1686050cf28SIvan Butygin Block *block1 = firstPloop.getBody(); 1696050cf28SIvan Butygin Block *block2 = secondPloop.getBody(); 1704d67b278SJeff Niu IRMapping firstToSecondPloopIndices; 1716050cf28SIvan Butygin firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); 172c25b20c0SAlex Zinenko 173c0d2ea9dSIvan Butygin if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, 174c0d2ea9dSIvan Butygin mayAlias)) 175c25b20c0SAlex Zinenko return; 176c25b20c0SAlex Zinenko 17797a2bd84SAlexander Belyaev DominanceInfo dom; 17897a2bd84SAlexander Belyaev // We are fusing first loop into second, make sure there are no users of the 17997a2bd84SAlexander Belyaev // first loop results between loops. 18097a2bd84SAlexander Belyaev for (Operation *user : firstPloop->getUsers()) 18197a2bd84SAlexander Belyaev if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) 18297a2bd84SAlexander Belyaev return; 18397a2bd84SAlexander Belyaev 18497a2bd84SAlexander Belyaev ValueRange inits1 = firstPloop.getInitVals(); 18597a2bd84SAlexander Belyaev ValueRange inits2 = secondPloop.getInitVals(); 18697a2bd84SAlexander Belyaev 18797a2bd84SAlexander Belyaev SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); 18897a2bd84SAlexander Belyaev newInitVars.append(inits2.begin(), inits2.end()); 18997a2bd84SAlexander Belyaev 19097a2bd84SAlexander Belyaev IRRewriter b(builder); 19197a2bd84SAlexander Belyaev b.setInsertionPoint(secondPloop); 19297a2bd84SAlexander Belyaev auto newSecondPloop = b.create<ParallelOp>( 19397a2bd84SAlexander Belyaev secondPloop.getLoc(), secondPloop.getLowerBound(), 19497a2bd84SAlexander Belyaev secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); 19597a2bd84SAlexander Belyaev 19697a2bd84SAlexander Belyaev Block *newBlock = newSecondPloop.getBody(); 19797a2bd84SAlexander Belyaev auto term1 = cast<ReduceOp>(block1->getTerminator()); 19897a2bd84SAlexander Belyaev auto term2 = cast<ReduceOp>(block2->getTerminator()); 19997a2bd84SAlexander Belyaev 20097a2bd84SAlexander Belyaev b.inlineBlockBefore(block2, newBlock, newBlock->begin(), 20197a2bd84SAlexander Belyaev newBlock->getArguments()); 20297a2bd84SAlexander Belyaev b.inlineBlockBefore(block1, newBlock, newBlock->begin(), 20397a2bd84SAlexander Belyaev newBlock->getArguments()); 20497a2bd84SAlexander Belyaev 20597a2bd84SAlexander Belyaev ValueRange results = newSecondPloop.getResults(); 20697a2bd84SAlexander Belyaev if (!results.empty()) { 20797a2bd84SAlexander Belyaev b.setInsertionPointToEnd(newBlock); 20897a2bd84SAlexander Belyaev 20997a2bd84SAlexander Belyaev ValueRange reduceArgs1 = term1.getOperands(); 21097a2bd84SAlexander Belyaev ValueRange reduceArgs2 = term2.getOperands(); 21197a2bd84SAlexander Belyaev SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); 21297a2bd84SAlexander Belyaev newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); 21397a2bd84SAlexander Belyaev 21497a2bd84SAlexander Belyaev auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); 21597a2bd84SAlexander Belyaev 21697a2bd84SAlexander Belyaev for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( 21797a2bd84SAlexander Belyaev term1.getReductions(), term2.getReductions()))) { 21897a2bd84SAlexander Belyaev Block &oldRedBlock = reg.front(); 21997a2bd84SAlexander Belyaev Block &newRedBlock = newReduceOp.getReductions()[i].front(); 22097a2bd84SAlexander Belyaev b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), 22197a2bd84SAlexander Belyaev newRedBlock.getArguments()); 22297a2bd84SAlexander Belyaev } 22397a2bd84SAlexander Belyaev 22497a2bd84SAlexander Belyaev firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); 22597a2bd84SAlexander Belyaev secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); 22697a2bd84SAlexander Belyaev } 22797a2bd84SAlexander Belyaev term1->erase(); 22897a2bd84SAlexander Belyaev term2->erase(); 22997a2bd84SAlexander Belyaev firstPloop.erase(); 23097a2bd84SAlexander Belyaev secondPloop.erase(); 23197a2bd84SAlexander Belyaev secondPloop = newSecondPloop; 232c25b20c0SAlex Zinenko } 233c25b20c0SAlex Zinenko 234c0d2ea9dSIvan Butygin void mlir::scf::naivelyFuseParallelOps( 235c0d2ea9dSIvan Butygin Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { 236c25b20c0SAlex Zinenko OpBuilder b(region); 237c25b20c0SAlex Zinenko // Consider every single block and attempt to fuse adjacent loops. 2386050cf28SIvan Butygin SmallVector<SmallVector<ParallelOp>, 1> ploopChains; 239c25b20c0SAlex Zinenko for (auto &block : region) { 2406050cf28SIvan Butygin ploopChains.clear(); 2416050cf28SIvan Butygin ploopChains.push_back({}); 2426050cf28SIvan Butygin 243c25b20c0SAlex Zinenko // Not using `walk()` to traverse only top-level parallel loops and also 244c25b20c0SAlex Zinenko // make sure that there are no side-effecting ops between the parallel 245c25b20c0SAlex Zinenko // loops. 246c25b20c0SAlex Zinenko bool noSideEffects = true; 247c25b20c0SAlex Zinenko for (auto &op : block) { 248c25b20c0SAlex Zinenko if (auto ploop = dyn_cast<ParallelOp>(op)) { 249c25b20c0SAlex Zinenko if (noSideEffects) { 250c25b20c0SAlex Zinenko ploopChains.back().push_back(ploop); 251c25b20c0SAlex Zinenko } else { 252c25b20c0SAlex Zinenko ploopChains.push_back({ploop}); 253c25b20c0SAlex Zinenko noSideEffects = true; 254c25b20c0SAlex Zinenko } 255c25b20c0SAlex Zinenko continue; 256c25b20c0SAlex Zinenko } 257c25b20c0SAlex Zinenko // TODO: Handle region side effects properly. 258fc367dfaSMahesh Ravishankar noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; 259c25b20c0SAlex Zinenko } 2606050cf28SIvan Butygin for (MutableArrayRef<ParallelOp> ploops : ploopChains) { 261c25b20c0SAlex Zinenko for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 262c0d2ea9dSIvan Butygin fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); 263c25b20c0SAlex Zinenko } 264c25b20c0SAlex Zinenko } 265c25b20c0SAlex Zinenko } 266c25b20c0SAlex Zinenko 267c25b20c0SAlex Zinenko namespace { 268039b969bSMichele Scuttari struct ParallelLoopFusion 26967d0d7acSMichele Scuttari : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { 270c25b20c0SAlex Zinenko void runOnOperation() override { 271c0d2ea9dSIvan Butygin auto &AA = getAnalysis<AliasAnalysis>(); 272c0d2ea9dSIvan Butygin 273c0d2ea9dSIvan Butygin auto mayAlias = [&](Value val1, Value val2) -> bool { 274c0d2ea9dSIvan Butygin return !AA.alias(val1, val2).isNo(); 275c0d2ea9dSIvan Butygin }; 276c0d2ea9dSIvan Butygin 277c25b20c0SAlex Zinenko getOperation()->walk([&](Operation *child) { 278c25b20c0SAlex Zinenko for (Region ®ion : child->getRegions()) 279c0d2ea9dSIvan Butygin naivelyFuseParallelOps(region, mayAlias); 280c25b20c0SAlex Zinenko }); 281c25b20c0SAlex Zinenko } 282c25b20c0SAlex Zinenko }; 283c25b20c0SAlex Zinenko } // namespace 284039b969bSMichele Scuttari 285039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 286039b969bSMichele Scuttari return std::make_unique<ParallelLoopFusion>(); 287039b969bSMichele Scuttari } 288