xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision c3eb2978a60b4e2e0cf9c8a8f9c51b48bd49477a)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/Interfaces/SideEffectInterfaces.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 /// Verify there are no nested ParallelOps.
34 static bool hasNestedParallelOp(ParallelOp ploop) {
35   auto walkResult =
36       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37   return walkResult.wasInterrupted();
38 }
39 
40 /// Verify equal iteration spaces.
41 static bool equalIterationSpaces(ParallelOp firstPloop,
42                                  ParallelOp secondPloop) {
43   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44     return false;
45 
46   auto matchOperands = [&](const OperandRange &lhs,
47                            const OperandRange &rhs) -> bool {
48     // TODO: Extend this to support aliases and equal constants.
49     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50   };
51   return matchOperands(firstPloop.getLowerBound(),
52                        secondPloop.getLowerBound()) &&
53          matchOperands(firstPloop.getUpperBound(),
54                        secondPloop.getUpperBound()) &&
55          matchOperands(firstPloop.getStep(), secondPloop.getStep());
56 }
57 
58 /// Checks if the parallel loops have mixed access to the same buffers. Returns
59 /// `true` if the first parallel loop writes to the same indices that the second
60 /// loop reads.
61 static bool haveNoReadsAfterWriteExceptSameIndex(
62     ParallelOp firstPloop, ParallelOp secondPloop,
63     const IRMapping &firstToSecondPloopIndices,
64     llvm::function_ref<bool(Value, Value)> mayAlias) {
65   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66   SmallVector<Value> bufferStoresVec;
67   firstPloop.getBody()->walk([&](memref::StoreOp store) {
68     bufferStores[store.getMemRef()].push_back(store.getIndices());
69     bufferStoresVec.emplace_back(store.getMemRef());
70   });
71   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72     Value loadMem = load.getMemRef();
73     // Stop if the memref is defined in secondPloop body. Careful alias analysis
74     // is needed.
75     auto *memrefDef = loadMem.getDefiningOp();
76     if (memrefDef && memrefDef->getBlock() == load->getBlock())
77       return WalkResult::interrupt();
78 
79     for (Value store : bufferStoresVec)
80       if (store != loadMem && mayAlias(store, loadMem))
81         return WalkResult::interrupt();
82 
83     auto write = bufferStores.find(loadMem);
84     if (write == bufferStores.end())
85       return WalkResult::advance();
86 
87     // Check that at last one store was retrieved
88     if (!write->second.size())
89       return WalkResult::interrupt();
90 
91     auto storeIndices = write->second.front();
92 
93     // Multiple writes to the same memref are allowed only on the same indices
94     for (const auto &othStoreIndices : write->second) {
95       if (othStoreIndices != storeIndices)
96         return WalkResult::interrupt();
97     }
98 
99     // Check that the load indices of secondPloop coincide with store indices of
100     // firstPloop for the same memrefs.
101     auto loadIndices = load.getIndices();
102     if (storeIndices.size() != loadIndices.size())
103       return WalkResult::interrupt();
104     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106           loadIndices[i]) {
107         auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108         auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109         if (storeIndexDefOp && loadIndexDefOp) {
110           if (!isMemoryEffectFree(storeIndexDefOp))
111             return WalkResult::interrupt();
112           if (!isMemoryEffectFree(loadIndexDefOp))
113             return WalkResult::interrupt();
114           if (!OperationEquivalence::isEquivalentTo(
115                   storeIndexDefOp, loadIndexDefOp,
116                   [&](Value storeIndex, Value loadIndex) {
117                     if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118                         firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119                       return failure();
120                     else
121                       return success();
122                   },
123                   /*markEquivalent=*/nullptr,
124                   OperationEquivalence::Flags::IgnoreLocations)) {
125             return WalkResult::interrupt();
126           }
127         } else
128           return WalkResult::interrupt();
129       }
130     }
131     return WalkResult::advance();
132   });
133   return !walkResult.wasInterrupted();
134 }
135 
136 /// Analyzes dependencies in the most primitive way by checking simple read and
137 /// write patterns.
138 static LogicalResult
139 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140                    const IRMapping &firstToSecondPloopIndices,
141                    llvm::function_ref<bool(Value, Value)> mayAlias) {
142   if (!haveNoReadsAfterWriteExceptSameIndex(
143           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144     return failure();
145 
146   IRMapping secondToFirstPloopIndices;
147   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148                                 firstPloop.getBody()->getArguments());
149   return success(haveNoReadsAfterWriteExceptSameIndex(
150       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151 }
152 
153 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154                           const IRMapping &firstToSecondPloopIndices,
155                           llvm::function_ref<bool(Value, Value)> mayAlias) {
156   return !hasNestedParallelOp(firstPloop) &&
157          !hasNestedParallelOp(secondPloop) &&
158          equalIterationSpaces(firstPloop, secondPloop) &&
159          succeeded(verifyDependencies(firstPloop, secondPloop,
160                                       firstToSecondPloopIndices, mayAlias));
161 }
162 
163 /// Prepends operations of firstPloop's body into secondPloop's body.
164 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
165                         OpBuilder b,
166                         llvm::function_ref<bool(Value, Value)> mayAlias) {
167   IRMapping firstToSecondPloopIndices;
168   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
169                                 secondPloop.getBody()->getArguments());
170 
171   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
172                      mayAlias))
173     return;
174 
175   b.setInsertionPointToStart(secondPloop.getBody());
176   for (auto &op : firstPloop.getBody()->without_terminator())
177     b.clone(op, firstToSecondPloopIndices);
178   firstPloop.erase();
179 }
180 
181 void mlir::scf::naivelyFuseParallelOps(
182     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
183   OpBuilder b(region);
184   // Consider every single block and attempt to fuse adjacent loops.
185   for (auto &block : region) {
186     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
187     // Not using `walk()` to traverse only top-level parallel loops and also
188     // make sure that there are no side-effecting ops between the parallel
189     // loops.
190     bool noSideEffects = true;
191     for (auto &op : block) {
192       if (auto ploop = dyn_cast<ParallelOp>(op)) {
193         if (noSideEffects) {
194           ploopChains.back().push_back(ploop);
195         } else {
196           ploopChains.push_back({ploop});
197           noSideEffects = true;
198         }
199         continue;
200       }
201       // TODO: Handle region side effects properly.
202       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
203     }
204     for (ArrayRef<ParallelOp> ploops : ploopChains) {
205       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
206         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
207     }
208   }
209 }
210 
211 namespace {
212 struct ParallelLoopFusion
213     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
214   void runOnOperation() override {
215     auto &AA = getAnalysis<AliasAnalysis>();
216 
217     auto mayAlias = [&](Value val1, Value val2) -> bool {
218       return !AA.alias(val1, val2).isNo();
219     };
220 
221     getOperation()->walk([&](Operation *child) {
222       for (Region &region : child->getRegions())
223         naivelyFuseParallelOps(region, mayAlias);
224     });
225   }
226 };
227 } // namespace
228 
229 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
230   return std::make_unique<ParallelLoopFusion>();
231 }
232