xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision b3a2208c566c475f7d1b6d40c67aec100ae29103)
1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko 
1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1467d0d7acSMichele Scuttari 
15c0d2ea9dSIvan Butygin #include "mlir/Analysis/AliasAnalysis.h"
16e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h"
204d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
21c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h"
22c3eb2978SHsiangkai Wang #include "mlir/IR/OperationSupport.h"
23fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
24c25b20c0SAlex Zinenko 
2567d0d7acSMichele Scuttari namespace mlir {
2667d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
2767d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2867d0d7acSMichele Scuttari } // namespace mlir
2967d0d7acSMichele Scuttari 
30c25b20c0SAlex Zinenko using namespace mlir;
31c25b20c0SAlex Zinenko using namespace mlir::scf;
32c25b20c0SAlex Zinenko 
33c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps.
34c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) {
35c25b20c0SAlex Zinenko   auto walkResult =
36c25b20c0SAlex Zinenko       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37c25b20c0SAlex Zinenko   return walkResult.wasInterrupted();
38c25b20c0SAlex Zinenko }
39c25b20c0SAlex Zinenko 
4097a2bd84SAlexander Belyaev /// Verify equal iteration spaces.
4197a2bd84SAlexander Belyaev static bool equalIterationSpaces(ParallelOp firstPloop,
4297a2bd84SAlexander Belyaev                                  ParallelOp secondPloop) {
4397a2bd84SAlexander Belyaev   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
4497a2bd84SAlexander Belyaev     return false;
4597a2bd84SAlexander Belyaev 
4697a2bd84SAlexander Belyaev   auto matchOperands = [&](const OperandRange &lhs,
4797a2bd84SAlexander Belyaev                            const OperandRange &rhs) -> bool {
4897a2bd84SAlexander Belyaev     // TODO: Extend this to support aliases and equal constants.
4997a2bd84SAlexander Belyaev     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
5097a2bd84SAlexander Belyaev   };
5197a2bd84SAlexander Belyaev   return matchOperands(firstPloop.getLowerBound(),
5297a2bd84SAlexander Belyaev                        secondPloop.getLowerBound()) &&
5397a2bd84SAlexander Belyaev          matchOperands(firstPloop.getUpperBound(),
5497a2bd84SAlexander Belyaev                        secondPloop.getUpperBound()) &&
5597a2bd84SAlexander Belyaev          matchOperands(firstPloop.getStep(), secondPloop.getStep());
5697a2bd84SAlexander Belyaev }
5797a2bd84SAlexander Belyaev 
58c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns
59c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second
60c25b20c0SAlex Zinenko /// loop reads.
61c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex(
62c25b20c0SAlex Zinenko     ParallelOp firstPloop, ParallelOp secondPloop,
63c0d2ea9dSIvan Butygin     const IRMapping &firstToSecondPloopIndices,
64c0d2ea9dSIvan Butygin     llvm::function_ref<bool(Value, Value)> mayAlias) {
65c25b20c0SAlex Zinenko   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66c0d2ea9dSIvan Butygin   SmallVector<Value> bufferStoresVec;
67e2310704SJulian Gross   firstPloop.getBody()->walk([&](memref::StoreOp store) {
68136d746eSJacques Pienaar     bufferStores[store.getMemRef()].push_back(store.getIndices());
69c0d2ea9dSIvan Butygin     bufferStoresVec.emplace_back(store.getMemRef());
70c25b20c0SAlex Zinenko   });
71e2310704SJulian Gross   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72c0d2ea9dSIvan Butygin     Value loadMem = load.getMemRef();
73c25b20c0SAlex Zinenko     // Stop if the memref is defined in secondPloop body. Careful alias analysis
74c25b20c0SAlex Zinenko     // is needed.
75c0d2ea9dSIvan Butygin     auto *memrefDef = loadMem.getDefiningOp();
76c4a04059SChristian Sigg     if (memrefDef && memrefDef->getBlock() == load->getBlock())
77c25b20c0SAlex Zinenko       return WalkResult::interrupt();
78c25b20c0SAlex Zinenko 
79c0d2ea9dSIvan Butygin     for (Value store : bufferStoresVec)
80c0d2ea9dSIvan Butygin       if (store != loadMem && mayAlias(store, loadMem))
81c0d2ea9dSIvan Butygin         return WalkResult::interrupt();
82c0d2ea9dSIvan Butygin 
83c0d2ea9dSIvan Butygin     auto write = bufferStores.find(loadMem);
84c25b20c0SAlex Zinenko     if (write == bufferStores.end())
85c25b20c0SAlex Zinenko       return WalkResult::advance();
86c25b20c0SAlex Zinenko 
87d17b005eSfabrizio-indirli     // Check that at last one store was retrieved
88*b3a2208cSAdrian Kuegel     if (write->second.empty())
89c25b20c0SAlex Zinenko       return WalkResult::interrupt();
90c25b20c0SAlex Zinenko 
91d17b005eSfabrizio-indirli     auto storeIndices = write->second.front();
92d17b005eSfabrizio-indirli 
93d17b005eSfabrizio-indirli     // Multiple writes to the same memref are allowed only on the same indices
94d17b005eSfabrizio-indirli     for (const auto &othStoreIndices : write->second) {
95d17b005eSfabrizio-indirli       if (othStoreIndices != storeIndices)
96d17b005eSfabrizio-indirli         return WalkResult::interrupt();
97d17b005eSfabrizio-indirli     }
98d17b005eSfabrizio-indirli 
99c25b20c0SAlex Zinenko     // Check that the load indices of secondPloop coincide with store indices of
100c25b20c0SAlex Zinenko     // firstPloop for the same memrefs.
101136d746eSJacques Pienaar     auto loadIndices = load.getIndices();
102c25b20c0SAlex Zinenko     if (storeIndices.size() != loadIndices.size())
103c25b20c0SAlex Zinenko       return WalkResult::interrupt();
104c25b20c0SAlex Zinenko     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105c25b20c0SAlex Zinenko       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106c3eb2978SHsiangkai Wang           loadIndices[i]) {
107c3eb2978SHsiangkai Wang         auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108c3eb2978SHsiangkai Wang         auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109c3eb2978SHsiangkai Wang         if (storeIndexDefOp && loadIndexDefOp) {
110c3eb2978SHsiangkai Wang           if (!isMemoryEffectFree(storeIndexDefOp))
111c25b20c0SAlex Zinenko             return WalkResult::interrupt();
112c3eb2978SHsiangkai Wang           if (!isMemoryEffectFree(loadIndexDefOp))
113c3eb2978SHsiangkai Wang             return WalkResult::interrupt();
114c3eb2978SHsiangkai Wang           if (!OperationEquivalence::isEquivalentTo(
115c3eb2978SHsiangkai Wang                   storeIndexDefOp, loadIndexDefOp,
116c3eb2978SHsiangkai Wang                   [&](Value storeIndex, Value loadIndex) {
117c3eb2978SHsiangkai Wang                     if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118c3eb2978SHsiangkai Wang                         firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119c3eb2978SHsiangkai Wang                       return failure();
120c3eb2978SHsiangkai Wang                     else
121c3eb2978SHsiangkai Wang                       return success();
122c3eb2978SHsiangkai Wang                   },
123c3eb2978SHsiangkai Wang                   /*markEquivalent=*/nullptr,
124c3eb2978SHsiangkai Wang                   OperationEquivalence::Flags::IgnoreLocations)) {
125c3eb2978SHsiangkai Wang             return WalkResult::interrupt();
126c3eb2978SHsiangkai Wang           }
127c3eb2978SHsiangkai Wang         } else
128c3eb2978SHsiangkai Wang           return WalkResult::interrupt();
129c3eb2978SHsiangkai Wang       }
130c25b20c0SAlex Zinenko     }
131c25b20c0SAlex Zinenko     return WalkResult::advance();
132c25b20c0SAlex Zinenko   });
133c25b20c0SAlex Zinenko   return !walkResult.wasInterrupted();
134c25b20c0SAlex Zinenko }
135c25b20c0SAlex Zinenko 
136c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and
137c25b20c0SAlex Zinenko /// write patterns.
138c25b20c0SAlex Zinenko static LogicalResult
139c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140c0d2ea9dSIvan Butygin                    const IRMapping &firstToSecondPloopIndices,
141c0d2ea9dSIvan Butygin                    llvm::function_ref<bool(Value, Value)> mayAlias) {
142c0d2ea9dSIvan Butygin   if (!haveNoReadsAfterWriteExceptSameIndex(
143c0d2ea9dSIvan Butygin           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144c25b20c0SAlex Zinenko     return failure();
145c25b20c0SAlex Zinenko 
1464d67b278SJeff Niu   IRMapping secondToFirstPloopIndices;
147c25b20c0SAlex Zinenko   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148c25b20c0SAlex Zinenko                                 firstPloop.getBody()->getArguments());
149c25b20c0SAlex Zinenko   return success(haveNoReadsAfterWriteExceptSameIndex(
150c0d2ea9dSIvan Butygin       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151c25b20c0SAlex Zinenko }
152c25b20c0SAlex Zinenko 
1534d67b278SJeff Niu static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154c0d2ea9dSIvan Butygin                           const IRMapping &firstToSecondPloopIndices,
155c0d2ea9dSIvan Butygin                           llvm::function_ref<bool(Value, Value)> mayAlias) {
156c25b20c0SAlex Zinenko   return !hasNestedParallelOp(firstPloop) &&
157c25b20c0SAlex Zinenko          !hasNestedParallelOp(secondPloop) &&
15897a2bd84SAlexander Belyaev          equalIterationSpaces(firstPloop, secondPloop) &&
159c25b20c0SAlex Zinenko          succeeded(verifyDependencies(firstPloop, secondPloop,
160c0d2ea9dSIvan Butygin                                       firstToSecondPloopIndices, mayAlias));
161c25b20c0SAlex Zinenko }
162c25b20c0SAlex Zinenko 
163c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body.
1646050cf28SIvan Butygin /// Updates secondPloop with new loop.
1656050cf28SIvan Butygin static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
1666050cf28SIvan Butygin                         OpBuilder builder,
167c0d2ea9dSIvan Butygin                         llvm::function_ref<bool(Value, Value)> mayAlias) {
1686050cf28SIvan Butygin   Block *block1 = firstPloop.getBody();
1696050cf28SIvan Butygin   Block *block2 = secondPloop.getBody();
1704d67b278SJeff Niu   IRMapping firstToSecondPloopIndices;
1716050cf28SIvan Butygin   firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
172c25b20c0SAlex Zinenko 
173c0d2ea9dSIvan Butygin   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
174c0d2ea9dSIvan Butygin                      mayAlias))
175c25b20c0SAlex Zinenko     return;
176c25b20c0SAlex Zinenko 
17797a2bd84SAlexander Belyaev   DominanceInfo dom;
17897a2bd84SAlexander Belyaev   // We are fusing first loop into second, make sure there are no users of the
17997a2bd84SAlexander Belyaev   // first loop results between loops.
18097a2bd84SAlexander Belyaev   for (Operation *user : firstPloop->getUsers())
18197a2bd84SAlexander Belyaev     if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
18297a2bd84SAlexander Belyaev       return;
18397a2bd84SAlexander Belyaev 
18497a2bd84SAlexander Belyaev   ValueRange inits1 = firstPloop.getInitVals();
18597a2bd84SAlexander Belyaev   ValueRange inits2 = secondPloop.getInitVals();
18697a2bd84SAlexander Belyaev 
18797a2bd84SAlexander Belyaev   SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
18897a2bd84SAlexander Belyaev   newInitVars.append(inits2.begin(), inits2.end());
18997a2bd84SAlexander Belyaev 
19097a2bd84SAlexander Belyaev   IRRewriter b(builder);
19197a2bd84SAlexander Belyaev   b.setInsertionPoint(secondPloop);
19297a2bd84SAlexander Belyaev   auto newSecondPloop = b.create<ParallelOp>(
19397a2bd84SAlexander Belyaev       secondPloop.getLoc(), secondPloop.getLowerBound(),
19497a2bd84SAlexander Belyaev       secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
19597a2bd84SAlexander Belyaev 
19697a2bd84SAlexander Belyaev   Block *newBlock = newSecondPloop.getBody();
19797a2bd84SAlexander Belyaev   auto term1 = cast<ReduceOp>(block1->getTerminator());
19897a2bd84SAlexander Belyaev   auto term2 = cast<ReduceOp>(block2->getTerminator());
19997a2bd84SAlexander Belyaev 
20097a2bd84SAlexander Belyaev   b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
20197a2bd84SAlexander Belyaev                       newBlock->getArguments());
20297a2bd84SAlexander Belyaev   b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
20397a2bd84SAlexander Belyaev                       newBlock->getArguments());
20497a2bd84SAlexander Belyaev 
20597a2bd84SAlexander Belyaev   ValueRange results = newSecondPloop.getResults();
20697a2bd84SAlexander Belyaev   if (!results.empty()) {
20797a2bd84SAlexander Belyaev     b.setInsertionPointToEnd(newBlock);
20897a2bd84SAlexander Belyaev 
20997a2bd84SAlexander Belyaev     ValueRange reduceArgs1 = term1.getOperands();
21097a2bd84SAlexander Belyaev     ValueRange reduceArgs2 = term2.getOperands();
21197a2bd84SAlexander Belyaev     SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
21297a2bd84SAlexander Belyaev     newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
21397a2bd84SAlexander Belyaev 
21497a2bd84SAlexander Belyaev     auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
21597a2bd84SAlexander Belyaev 
21697a2bd84SAlexander Belyaev     for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
21797a2bd84SAlexander Belyaev              term1.getReductions(), term2.getReductions()))) {
21897a2bd84SAlexander Belyaev       Block &oldRedBlock = reg.front();
21997a2bd84SAlexander Belyaev       Block &newRedBlock = newReduceOp.getReductions()[i].front();
22097a2bd84SAlexander Belyaev       b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
22197a2bd84SAlexander Belyaev                           newRedBlock.getArguments());
22297a2bd84SAlexander Belyaev     }
22397a2bd84SAlexander Belyaev 
22497a2bd84SAlexander Belyaev     firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
22597a2bd84SAlexander Belyaev     secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
22697a2bd84SAlexander Belyaev   }
22797a2bd84SAlexander Belyaev   term1->erase();
22897a2bd84SAlexander Belyaev   term2->erase();
22997a2bd84SAlexander Belyaev   firstPloop.erase();
23097a2bd84SAlexander Belyaev   secondPloop.erase();
23197a2bd84SAlexander Belyaev   secondPloop = newSecondPloop;
232c25b20c0SAlex Zinenko }
233c25b20c0SAlex Zinenko 
234c0d2ea9dSIvan Butygin void mlir::scf::naivelyFuseParallelOps(
235c0d2ea9dSIvan Butygin     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
236c25b20c0SAlex Zinenko   OpBuilder b(region);
237c25b20c0SAlex Zinenko   // Consider every single block and attempt to fuse adjacent loops.
2386050cf28SIvan Butygin   SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
239c25b20c0SAlex Zinenko   for (auto &block : region) {
2406050cf28SIvan Butygin     ploopChains.clear();
2416050cf28SIvan Butygin     ploopChains.push_back({});
2426050cf28SIvan Butygin 
243c25b20c0SAlex Zinenko     // Not using `walk()` to traverse only top-level parallel loops and also
244c25b20c0SAlex Zinenko     // make sure that there are no side-effecting ops between the parallel
245c25b20c0SAlex Zinenko     // loops.
246c25b20c0SAlex Zinenko     bool noSideEffects = true;
247c25b20c0SAlex Zinenko     for (auto &op : block) {
248c25b20c0SAlex Zinenko       if (auto ploop = dyn_cast<ParallelOp>(op)) {
249c25b20c0SAlex Zinenko         if (noSideEffects) {
250c25b20c0SAlex Zinenko           ploopChains.back().push_back(ploop);
251c25b20c0SAlex Zinenko         } else {
252c25b20c0SAlex Zinenko           ploopChains.push_back({ploop});
253c25b20c0SAlex Zinenko           noSideEffects = true;
254c25b20c0SAlex Zinenko         }
255c25b20c0SAlex Zinenko         continue;
256c25b20c0SAlex Zinenko       }
257c25b20c0SAlex Zinenko       // TODO: Handle region side effects properly.
258fc367dfaSMahesh Ravishankar       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
259c25b20c0SAlex Zinenko     }
2606050cf28SIvan Butygin     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
261c25b20c0SAlex Zinenko       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
262c0d2ea9dSIvan Butygin         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
263c25b20c0SAlex Zinenko     }
264c25b20c0SAlex Zinenko   }
265c25b20c0SAlex Zinenko }
266c25b20c0SAlex Zinenko 
267c25b20c0SAlex Zinenko namespace {
268039b969bSMichele Scuttari struct ParallelLoopFusion
26967d0d7acSMichele Scuttari     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
270c25b20c0SAlex Zinenko   void runOnOperation() override {
271c0d2ea9dSIvan Butygin     auto &AA = getAnalysis<AliasAnalysis>();
272c0d2ea9dSIvan Butygin 
273c0d2ea9dSIvan Butygin     auto mayAlias = [&](Value val1, Value val2) -> bool {
274c0d2ea9dSIvan Butygin       return !AA.alias(val1, val2).isNo();
275c0d2ea9dSIvan Butygin     };
276c0d2ea9dSIvan Butygin 
277c25b20c0SAlex Zinenko     getOperation()->walk([&](Operation *child) {
278c25b20c0SAlex Zinenko       for (Region &region : child->getRegions())
279c0d2ea9dSIvan Butygin         naivelyFuseParallelOps(region, mayAlias);
280c25b20c0SAlex Zinenko     });
281c25b20c0SAlex Zinenko   }
282c25b20c0SAlex Zinenko };
283c25b20c0SAlex Zinenko } // namespace
284039b969bSMichele Scuttari 
285039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
286039b969bSMichele Scuttari   return std::make_unique<ParallelLoopFusion>();
287039b969bSMichele Scuttari }
288