xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision 6820b0871807abff07df118659e0de2ca741cb0b)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
28 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::scf;
33 
34 /// Verify there are no nested ParallelOps.
35 static bool hasNestedParallelOp(ParallelOp ploop) {
36   auto walkResult =
37       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
38   return walkResult.wasInterrupted();
39 }
40 
41 /// Checks if the parallel loops have mixed access to the same buffers. Returns
42 /// `true` if the first parallel loop writes to the same indices that the second
43 /// loop reads.
44 static bool haveNoReadsAfterWriteExceptSameIndex(
45     ParallelOp firstPloop, ParallelOp secondPloop,
46     const IRMapping &firstToSecondPloopIndices,
47     llvm::function_ref<bool(Value, Value)> mayAlias) {
48   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
49   SmallVector<Value> bufferStoresVec;
50   firstPloop.getBody()->walk([&](memref::StoreOp store) {
51     bufferStores[store.getMemRef()].push_back(store.getIndices());
52     bufferStoresVec.emplace_back(store.getMemRef());
53   });
54   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
55     Value loadMem = load.getMemRef();
56     // Stop if the memref is defined in secondPloop body. Careful alias analysis
57     // is needed.
58     auto *memrefDef = loadMem.getDefiningOp();
59     if (memrefDef && memrefDef->getBlock() == load->getBlock())
60       return WalkResult::interrupt();
61 
62     for (Value store : bufferStoresVec)
63       if (store != loadMem && mayAlias(store, loadMem))
64         return WalkResult::interrupt();
65 
66     auto write = bufferStores.find(loadMem);
67     if (write == bufferStores.end())
68       return WalkResult::advance();
69 
70     // Check that at last one store was retrieved
71     if (!write->second.size())
72       return WalkResult::interrupt();
73 
74     auto storeIndices = write->second.front();
75 
76     // Multiple writes to the same memref are allowed only on the same indices
77     for (const auto &othStoreIndices : write->second) {
78       if (othStoreIndices != storeIndices)
79         return WalkResult::interrupt();
80     }
81 
82     // Check that the load indices of secondPloop coincide with store indices of
83     // firstPloop for the same memrefs.
84     auto loadIndices = load.getIndices();
85     if (storeIndices.size() != loadIndices.size())
86       return WalkResult::interrupt();
87     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
88       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
89           loadIndices[i]) {
90         auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
91         auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
92         if (storeIndexDefOp && loadIndexDefOp) {
93           if (!isMemoryEffectFree(storeIndexDefOp))
94             return WalkResult::interrupt();
95           if (!isMemoryEffectFree(loadIndexDefOp))
96             return WalkResult::interrupt();
97           if (!OperationEquivalence::isEquivalentTo(
98                   storeIndexDefOp, loadIndexDefOp,
99                   [&](Value storeIndex, Value loadIndex) {
100                     if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
101                         firstToSecondPloopIndices.lookupOrDefault(loadIndex))
102                       return failure();
103                     else
104                       return success();
105                   },
106                   /*markEquivalent=*/nullptr,
107                   OperationEquivalence::Flags::IgnoreLocations)) {
108             return WalkResult::interrupt();
109           }
110         } else
111           return WalkResult::interrupt();
112       }
113     }
114     return WalkResult::advance();
115   });
116   return !walkResult.wasInterrupted();
117 }
118 
119 /// Analyzes dependencies in the most primitive way by checking simple read and
120 /// write patterns.
121 static LogicalResult
122 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
123                    const IRMapping &firstToSecondPloopIndices,
124                    llvm::function_ref<bool(Value, Value)> mayAlias) {
125   if (!haveNoReadsAfterWriteExceptSameIndex(
126           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
127     return failure();
128 
129   IRMapping secondToFirstPloopIndices;
130   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
131                                 firstPloop.getBody()->getArguments());
132   return success(haveNoReadsAfterWriteExceptSameIndex(
133       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
134 }
135 
136 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
137                           const IRMapping &firstToSecondPloopIndices,
138                           llvm::function_ref<bool(Value, Value)> mayAlias) {
139   Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
140   return !hasNestedParallelOp(firstPloop) &&
141          !hasNestedParallelOp(secondPloop) &&
142          checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
143          succeeded(verifyDependencies(firstPloop, secondPloop,
144                                       firstToSecondPloopIndices, mayAlias));
145 }
146 
147 /// Prepends operations of firstPloop's body into secondPloop's body.
148 /// Updates secondPloop with new loop.
149 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
150                         OpBuilder builder,
151                         llvm::function_ref<bool(Value, Value)> mayAlias) {
152   Block *block1 = firstPloop.getBody();
153   Block *block2 = secondPloop.getBody();
154   IRMapping firstToSecondPloopIndices;
155   firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
156 
157   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
158                      mayAlias))
159     return;
160 
161   IRRewriter rewriter(builder);
162   secondPloop = mlir::fuseIndependentSiblingParallelLoops(
163       firstPloop, secondPloop, rewriter);
164 }
165 
166 void mlir::scf::naivelyFuseParallelOps(
167     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
168   OpBuilder b(region);
169   // Consider every single block and attempt to fuse adjacent loops.
170   SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
171   for (auto &block : region) {
172     ploopChains.clear();
173     ploopChains.push_back({});
174 
175     // Not using `walk()` to traverse only top-level parallel loops and also
176     // make sure that there are no side-effecting ops between the parallel
177     // loops.
178     bool noSideEffects = true;
179     for (auto &op : block) {
180       if (auto ploop = dyn_cast<ParallelOp>(op)) {
181         if (noSideEffects) {
182           ploopChains.back().push_back(ploop);
183         } else {
184           ploopChains.push_back({ploop});
185           noSideEffects = true;
186         }
187         continue;
188       }
189       // TODO: Handle region side effects properly.
190       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
191     }
192     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
193       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
194         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
195     }
196   }
197 }
198 
199 namespace {
200 struct ParallelLoopFusion
201     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
202   void runOnOperation() override {
203     auto &AA = getAnalysis<AliasAnalysis>();
204 
205     auto mayAlias = [&](Value val1, Value val2) -> bool {
206       return !AA.alias(val1, val2).isNo();
207     };
208 
209     getOperation()->walk([&](Operation *child) {
210       for (Region &region : child->getRegions())
211         naivelyFuseParallelOps(region, mayAlias);
212     });
213   }
214 };
215 } // namespace
216 
217 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
218   return std::make_unique<ParallelLoopFusion>();
219 }
220