xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision b3a2208c566c475f7d1b6d40c67aec100ae29103)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/Interfaces/SideEffectInterfaces.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 /// Verify there are no nested ParallelOps.
34 static bool hasNestedParallelOp(ParallelOp ploop) {
35   auto walkResult =
36       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37   return walkResult.wasInterrupted();
38 }
39 
40 /// Verify equal iteration spaces.
41 static bool equalIterationSpaces(ParallelOp firstPloop,
42                                  ParallelOp secondPloop) {
43   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44     return false;
45 
46   auto matchOperands = [&](const OperandRange &lhs,
47                            const OperandRange &rhs) -> bool {
48     // TODO: Extend this to support aliases and equal constants.
49     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50   };
51   return matchOperands(firstPloop.getLowerBound(),
52                        secondPloop.getLowerBound()) &&
53          matchOperands(firstPloop.getUpperBound(),
54                        secondPloop.getUpperBound()) &&
55          matchOperands(firstPloop.getStep(), secondPloop.getStep());
56 }
57 
58 /// Checks if the parallel loops have mixed access to the same buffers. Returns
59 /// `true` if the first parallel loop writes to the same indices that the second
60 /// loop reads.
61 static bool haveNoReadsAfterWriteExceptSameIndex(
62     ParallelOp firstPloop, ParallelOp secondPloop,
63     const IRMapping &firstToSecondPloopIndices,
64     llvm::function_ref<bool(Value, Value)> mayAlias) {
65   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66   SmallVector<Value> bufferStoresVec;
67   firstPloop.getBody()->walk([&](memref::StoreOp store) {
68     bufferStores[store.getMemRef()].push_back(store.getIndices());
69     bufferStoresVec.emplace_back(store.getMemRef());
70   });
71   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72     Value loadMem = load.getMemRef();
73     // Stop if the memref is defined in secondPloop body. Careful alias analysis
74     // is needed.
75     auto *memrefDef = loadMem.getDefiningOp();
76     if (memrefDef && memrefDef->getBlock() == load->getBlock())
77       return WalkResult::interrupt();
78 
79     for (Value store : bufferStoresVec)
80       if (store != loadMem && mayAlias(store, loadMem))
81         return WalkResult::interrupt();
82 
83     auto write = bufferStores.find(loadMem);
84     if (write == bufferStores.end())
85       return WalkResult::advance();
86 
87     // Check that at last one store was retrieved
88     if (write->second.empty())
89       return WalkResult::interrupt();
90 
91     auto storeIndices = write->second.front();
92 
93     // Multiple writes to the same memref are allowed only on the same indices
94     for (const auto &othStoreIndices : write->second) {
95       if (othStoreIndices != storeIndices)
96         return WalkResult::interrupt();
97     }
98 
99     // Check that the load indices of secondPloop coincide with store indices of
100     // firstPloop for the same memrefs.
101     auto loadIndices = load.getIndices();
102     if (storeIndices.size() != loadIndices.size())
103       return WalkResult::interrupt();
104     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106           loadIndices[i]) {
107         auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108         auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109         if (storeIndexDefOp && loadIndexDefOp) {
110           if (!isMemoryEffectFree(storeIndexDefOp))
111             return WalkResult::interrupt();
112           if (!isMemoryEffectFree(loadIndexDefOp))
113             return WalkResult::interrupt();
114           if (!OperationEquivalence::isEquivalentTo(
115                   storeIndexDefOp, loadIndexDefOp,
116                   [&](Value storeIndex, Value loadIndex) {
117                     if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118                         firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119                       return failure();
120                     else
121                       return success();
122                   },
123                   /*markEquivalent=*/nullptr,
124                   OperationEquivalence::Flags::IgnoreLocations)) {
125             return WalkResult::interrupt();
126           }
127         } else
128           return WalkResult::interrupt();
129       }
130     }
131     return WalkResult::advance();
132   });
133   return !walkResult.wasInterrupted();
134 }
135 
136 /// Analyzes dependencies in the most primitive way by checking simple read and
137 /// write patterns.
138 static LogicalResult
139 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140                    const IRMapping &firstToSecondPloopIndices,
141                    llvm::function_ref<bool(Value, Value)> mayAlias) {
142   if (!haveNoReadsAfterWriteExceptSameIndex(
143           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144     return failure();
145 
146   IRMapping secondToFirstPloopIndices;
147   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148                                 firstPloop.getBody()->getArguments());
149   return success(haveNoReadsAfterWriteExceptSameIndex(
150       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151 }
152 
153 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154                           const IRMapping &firstToSecondPloopIndices,
155                           llvm::function_ref<bool(Value, Value)> mayAlias) {
156   return !hasNestedParallelOp(firstPloop) &&
157          !hasNestedParallelOp(secondPloop) &&
158          equalIterationSpaces(firstPloop, secondPloop) &&
159          succeeded(verifyDependencies(firstPloop, secondPloop,
160                                       firstToSecondPloopIndices, mayAlias));
161 }
162 
163 /// Prepends operations of firstPloop's body into secondPloop's body.
164 /// Updates secondPloop with new loop.
165 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
166                         OpBuilder builder,
167                         llvm::function_ref<bool(Value, Value)> mayAlias) {
168   Block *block1 = firstPloop.getBody();
169   Block *block2 = secondPloop.getBody();
170   IRMapping firstToSecondPloopIndices;
171   firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
172 
173   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
174                      mayAlias))
175     return;
176 
177   DominanceInfo dom;
178   // We are fusing first loop into second, make sure there are no users of the
179   // first loop results between loops.
180   for (Operation *user : firstPloop->getUsers())
181     if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182       return;
183 
184   ValueRange inits1 = firstPloop.getInitVals();
185   ValueRange inits2 = secondPloop.getInitVals();
186 
187   SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188   newInitVars.append(inits2.begin(), inits2.end());
189 
190   IRRewriter b(builder);
191   b.setInsertionPoint(secondPloop);
192   auto newSecondPloop = b.create<ParallelOp>(
193       secondPloop.getLoc(), secondPloop.getLowerBound(),
194       secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195 
196   Block *newBlock = newSecondPloop.getBody();
197   auto term1 = cast<ReduceOp>(block1->getTerminator());
198   auto term2 = cast<ReduceOp>(block2->getTerminator());
199 
200   b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201                       newBlock->getArguments());
202   b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203                       newBlock->getArguments());
204 
205   ValueRange results = newSecondPloop.getResults();
206   if (!results.empty()) {
207     b.setInsertionPointToEnd(newBlock);
208 
209     ValueRange reduceArgs1 = term1.getOperands();
210     ValueRange reduceArgs2 = term2.getOperands();
211     SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212     newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213 
214     auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215 
216     for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217              term1.getReductions(), term2.getReductions()))) {
218       Block &oldRedBlock = reg.front();
219       Block &newRedBlock = newReduceOp.getReductions()[i].front();
220       b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221                           newRedBlock.getArguments());
222     }
223 
224     firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225     secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226   }
227   term1->erase();
228   term2->erase();
229   firstPloop.erase();
230   secondPloop.erase();
231   secondPloop = newSecondPloop;
232 }
233 
234 void mlir::scf::naivelyFuseParallelOps(
235     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
236   OpBuilder b(region);
237   // Consider every single block and attempt to fuse adjacent loops.
238   SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
239   for (auto &block : region) {
240     ploopChains.clear();
241     ploopChains.push_back({});
242 
243     // Not using `walk()` to traverse only top-level parallel loops and also
244     // make sure that there are no side-effecting ops between the parallel
245     // loops.
246     bool noSideEffects = true;
247     for (auto &op : block) {
248       if (auto ploop = dyn_cast<ParallelOp>(op)) {
249         if (noSideEffects) {
250           ploopChains.back().push_back(ploop);
251         } else {
252           ploopChains.push_back({ploop});
253           noSideEffects = true;
254         }
255         continue;
256       }
257       // TODO: Handle region side effects properly.
258       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
259     }
260     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
261       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
262         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
263     }
264   }
265 }
266 
267 namespace {
268 struct ParallelLoopFusion
269     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
270   void runOnOperation() override {
271     auto &AA = getAnalysis<AliasAnalysis>();
272 
273     auto mayAlias = [&](Value val1, Value val2) -> bool {
274       return !AA.alias(val1, val2).isNo();
275     };
276 
277     getOperation()->walk([&](Operation *child) {
278       for (Region &region : child->getRegions())
279         naivelyFuseParallelOps(region, mayAlias);
280     });
281   }
282 };
283 } // namespace
284 
285 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
286   return std::make_unique<ParallelLoopFusion>();
287 }
288