xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision d17b005e46e240b2c95801a14e2c6fc5baa5b3f7)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
26 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 /// Verify there are no nested ParallelOps.
33 static bool hasNestedParallelOp(ParallelOp ploop) {
34   auto walkResult =
35       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
36   return walkResult.wasInterrupted();
37 }
38 
39 /// Verify equal iteration spaces.
40 static bool equalIterationSpaces(ParallelOp firstPloop,
41                                  ParallelOp secondPloop) {
42   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
43     return false;
44 
45   auto matchOperands = [&](const OperandRange &lhs,
46                            const OperandRange &rhs) -> bool {
47     // TODO: Extend this to support aliases and equal constants.
48     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
49   };
50   return matchOperands(firstPloop.getLowerBound(),
51                        secondPloop.getLowerBound()) &&
52          matchOperands(firstPloop.getUpperBound(),
53                        secondPloop.getUpperBound()) &&
54          matchOperands(firstPloop.getStep(), secondPloop.getStep());
55 }
56 
57 /// Checks if the parallel loops have mixed access to the same buffers. Returns
58 /// `true` if the first parallel loop writes to the same indices that the second
59 /// loop reads.
60 static bool haveNoReadsAfterWriteExceptSameIndex(
61     ParallelOp firstPloop, ParallelOp secondPloop,
62     const IRMapping &firstToSecondPloopIndices,
63     llvm::function_ref<bool(Value, Value)> mayAlias) {
64   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
65   SmallVector<Value> bufferStoresVec;
66   firstPloop.getBody()->walk([&](memref::StoreOp store) {
67     bufferStores[store.getMemRef()].push_back(store.getIndices());
68     bufferStoresVec.emplace_back(store.getMemRef());
69   });
70   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
71     Value loadMem = load.getMemRef();
72     // Stop if the memref is defined in secondPloop body. Careful alias analysis
73     // is needed.
74     auto *memrefDef = loadMem.getDefiningOp();
75     if (memrefDef && memrefDef->getBlock() == load->getBlock())
76       return WalkResult::interrupt();
77 
78     for (Value store : bufferStoresVec)
79       if (store != loadMem && mayAlias(store, loadMem))
80         return WalkResult::interrupt();
81 
82     auto write = bufferStores.find(loadMem);
83     if (write == bufferStores.end())
84       return WalkResult::advance();
85 
86     // Check that at last one store was retrieved
87     if (!write->second.size())
88       return WalkResult::interrupt();
89 
90     auto storeIndices = write->second.front();
91 
92     // Multiple writes to the same memref are allowed only on the same indices
93     for (const auto &othStoreIndices : write->second) {
94       if (othStoreIndices != storeIndices)
95         return WalkResult::interrupt();
96     }
97 
98     // Check that the load indices of secondPloop coincide with store indices of
99     // firstPloop for the same memrefs.
100     auto loadIndices = load.getIndices();
101     if (storeIndices.size() != loadIndices.size())
102       return WalkResult::interrupt();
103     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
104       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
105           loadIndices[i])
106         return WalkResult::interrupt();
107     }
108     return WalkResult::advance();
109   });
110   return !walkResult.wasInterrupted();
111 }
112 
113 /// Analyzes dependencies in the most primitive way by checking simple read and
114 /// write patterns.
115 static LogicalResult
116 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
117                    const IRMapping &firstToSecondPloopIndices,
118                    llvm::function_ref<bool(Value, Value)> mayAlias) {
119   if (!haveNoReadsAfterWriteExceptSameIndex(
120           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
121     return failure();
122 
123   IRMapping secondToFirstPloopIndices;
124   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
125                                 firstPloop.getBody()->getArguments());
126   return success(haveNoReadsAfterWriteExceptSameIndex(
127       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
128 }
129 
130 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
131                           const IRMapping &firstToSecondPloopIndices,
132                           llvm::function_ref<bool(Value, Value)> mayAlias) {
133   return !hasNestedParallelOp(firstPloop) &&
134          !hasNestedParallelOp(secondPloop) &&
135          equalIterationSpaces(firstPloop, secondPloop) &&
136          succeeded(verifyDependencies(firstPloop, secondPloop,
137                                       firstToSecondPloopIndices, mayAlias));
138 }
139 
140 /// Prepends operations of firstPloop's body into secondPloop's body.
141 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
142                         OpBuilder b,
143                         llvm::function_ref<bool(Value, Value)> mayAlias) {
144   IRMapping firstToSecondPloopIndices;
145   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
146                                 secondPloop.getBody()->getArguments());
147 
148   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
149                      mayAlias))
150     return;
151 
152   b.setInsertionPointToStart(secondPloop.getBody());
153   for (auto &op : firstPloop.getBody()->without_terminator())
154     b.clone(op, firstToSecondPloopIndices);
155   firstPloop.erase();
156 }
157 
158 void mlir::scf::naivelyFuseParallelOps(
159     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
160   OpBuilder b(region);
161   // Consider every single block and attempt to fuse adjacent loops.
162   for (auto &block : region) {
163     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
164     // Not using `walk()` to traverse only top-level parallel loops and also
165     // make sure that there are no side-effecting ops between the parallel
166     // loops.
167     bool noSideEffects = true;
168     for (auto &op : block) {
169       if (auto ploop = dyn_cast<ParallelOp>(op)) {
170         if (noSideEffects) {
171           ploopChains.back().push_back(ploop);
172         } else {
173           ploopChains.push_back({ploop});
174           noSideEffects = true;
175         }
176         continue;
177       }
178       // TODO: Handle region side effects properly.
179       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
180     }
181     for (ArrayRef<ParallelOp> ploops : ploopChains) {
182       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
183         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
184     }
185   }
186 }
187 
188 namespace {
189 struct ParallelLoopFusion
190     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
191   void runOnOperation() override {
192     auto &AA = getAnalysis<AliasAnalysis>();
193 
194     auto mayAlias = [&](Value val1, Value val2) -> bool {
195       return !AA.alias(val1, val2).isNo();
196     };
197 
198     getOperation()->walk([&](Operation *child) {
199       for (Region &region : child->getRegions())
200         naivelyFuseParallelOps(region, mayAlias);
201     });
202   }
203 };
204 } // namespace
205 
206 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
207   return std::make_unique<ParallelLoopFusion>();
208 }
209