//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements loop fusion on parallel loops. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::scf; /// Verify there are no nested ParallelOps. static bool hasNestedParallelOp(ParallelOp ploop) { auto walkResult = ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); return walkResult.wasInterrupted(); } /// Verify equal iteration spaces. static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop) { if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) return false; auto matchOperands = [&](const OperandRange &lhs, const OperandRange &rhs) -> bool { // TODO: Extend this to support aliases and equal constants. return std::equal(lhs.begin(), lhs.end(), rhs.begin()); }; return matchOperands(firstPloop.getLowerBound(), secondPloop.getLowerBound()) && matchOperands(firstPloop.getUpperBound(), secondPloop.getUpperBound()) && matchOperands(firstPloop.getStep(), secondPloop.getStep()); } /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. static bool haveNoReadsAfterWriteExceptSameIndex( ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { DenseMap> bufferStores; SmallVector bufferStoresVec; firstPloop.getBody()->walk([&](memref::StoreOp store) { bufferStores[store.getMemRef()].push_back(store.getIndices()); bufferStoresVec.emplace_back(store.getMemRef()); }); auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { Value loadMem = load.getMemRef(); // Stop if the memref is defined in secondPloop body. Careful alias analysis // is needed. auto *memrefDef = loadMem.getDefiningOp(); if (memrefDef && memrefDef->getBlock() == load->getBlock()) return WalkResult::interrupt(); for (Value store : bufferStoresVec) if (store != loadMem && mayAlias(store, loadMem)) return WalkResult::interrupt(); auto write = bufferStores.find(loadMem); if (write == bufferStores.end()) return WalkResult::advance(); // Check that at last one store was retrieved if (write->second.empty()) return WalkResult::interrupt(); auto storeIndices = write->second.front(); // Multiple writes to the same memref are allowed only on the same indices for (const auto &othStoreIndices : write->second) { if (othStoreIndices != storeIndices) return WalkResult::interrupt(); } // Check that the load indices of secondPloop coincide with store indices of // firstPloop for the same memrefs. auto loadIndices = load.getIndices(); if (storeIndices.size() != loadIndices.size()) return WalkResult::interrupt(); for (int i = 0, e = storeIndices.size(); i < e; ++i) { if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != loadIndices[i]) { auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); if (storeIndexDefOp && loadIndexDefOp) { if (!isMemoryEffectFree(storeIndexDefOp)) return WalkResult::interrupt(); if (!isMemoryEffectFree(loadIndexDefOp)) return WalkResult::interrupt(); if (!OperationEquivalence::isEquivalentTo( storeIndexDefOp, loadIndexDefOp, [&](Value storeIndex, Value loadIndex) { if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != firstToSecondPloopIndices.lookupOrDefault(loadIndex)) return failure(); else return success(); }, /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations)) { return WalkResult::interrupt(); } } else return WalkResult::interrupt(); } } return WalkResult::advance(); }); return !walkResult.wasInterrupted(); } /// Analyzes dependencies in the most primitive way by checking simple read and /// write patterns. static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { if (!haveNoReadsAfterWriteExceptSameIndex( firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) return failure(); IRMapping secondToFirstPloopIndices; secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), firstPloop.getBody()->getArguments()); return success(haveNoReadsAfterWriteExceptSameIndex( secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); } static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && equalIterationSpaces(firstPloop, secondPloop) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } /// Prepends operations of firstPloop's body into secondPloop's body. /// Updates secondPloop with new loop. static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref mayAlias) { Block *block1 = firstPloop.getBody(); Block *block2 = secondPloop.getBody(); IRMapping firstToSecondPloopIndices; firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) return; DominanceInfo dom; // We are fusing first loop into second, make sure there are no users of the // first loop results between loops. for (Operation *user : firstPloop->getUsers()) if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) return; ValueRange inits1 = firstPloop.getInitVals(); ValueRange inits2 = secondPloop.getInitVals(); SmallVector newInitVars(inits1.begin(), inits1.end()); newInitVars.append(inits2.begin(), inits2.end()); IRRewriter b(builder); b.setInsertionPoint(secondPloop); auto newSecondPloop = b.create( secondPloop.getLoc(), secondPloop.getLowerBound(), secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); Block *newBlock = newSecondPloop.getBody(); auto term1 = cast(block1->getTerminator()); auto term2 = cast(block2->getTerminator()); b.inlineBlockBefore(block2, newBlock, newBlock->begin(), newBlock->getArguments()); b.inlineBlockBefore(block1, newBlock, newBlock->begin(), newBlock->getArguments()); ValueRange results = newSecondPloop.getResults(); if (!results.empty()) { b.setInsertionPointToEnd(newBlock); ValueRange reduceArgs1 = term1.getOperands(); ValueRange reduceArgs2 = term2.getOperands(); SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); for (auto &&[i, reg] : llvm::enumerate(llvm::concat( term1.getReductions(), term2.getReductions()))) { Block &oldRedBlock = reg.front(); Block &newRedBlock = newReduceOp.getReductions()[i].front(); b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), newRedBlock.getArguments()); } firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); } term1->erase(); term2->erase(); firstPloop.erase(); secondPloop.erase(); secondPloop = newSecondPloop; } void mlir::scf::naivelyFuseParallelOps( Region ®ion, llvm::function_ref mayAlias) { OpBuilder b(region); // Consider every single block and attempt to fuse adjacent loops. SmallVector, 1> ploopChains; for (auto &block : region) { ploopChains.clear(); ploopChains.push_back({}); // Not using `walk()` to traverse only top-level parallel loops and also // make sure that there are no side-effecting ops between the parallel // loops. bool noSideEffects = true; for (auto &op : block) { if (auto ploop = dyn_cast(op)) { if (noSideEffects) { ploopChains.back().push_back(ploop); } else { ploopChains.push_back({ploop}); noSideEffects = true; } continue; } // TODO: Handle region side effects properly. noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; } for (MutableArrayRef ploops : ploopChains) { for (int i = 0, e = ploops.size(); i + 1 < e; ++i) fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); } } } namespace { struct ParallelLoopFusion : public impl::SCFParallelLoopFusionBase { void runOnOperation() override { auto &AA = getAnalysis(); auto mayAlias = [&](Value val1, Value val2) -> bool { return !AA.alias(val1, val2).isNo(); }; getOperation()->walk([&](Operation *child) { for (Region ®ion : child->getRegions()) naivelyFuseParallelOps(region, mayAlias); }); } }; } // namespace std::unique_ptr mlir::createParallelLoopFusionPass() { return std::make_unique(); }