xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision c0d2ea9d4202c7cce4214b3057a709ff2f1128ae)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
26 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 /// Verify there are no nested ParallelOps.
33 static bool hasNestedParallelOp(ParallelOp ploop) {
34   auto walkResult =
35       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
36   return walkResult.wasInterrupted();
37 }
38 
39 /// Verify equal iteration spaces.
40 static bool equalIterationSpaces(ParallelOp firstPloop,
41                                  ParallelOp secondPloop) {
42   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
43     return false;
44 
45   auto matchOperands = [&](const OperandRange &lhs,
46                            const OperandRange &rhs) -> bool {
47     // TODO: Extend this to support aliases and equal constants.
48     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
49   };
50   return matchOperands(firstPloop.getLowerBound(),
51                        secondPloop.getLowerBound()) &&
52          matchOperands(firstPloop.getUpperBound(),
53                        secondPloop.getUpperBound()) &&
54          matchOperands(firstPloop.getStep(), secondPloop.getStep());
55 }
56 
57 /// Checks if the parallel loops have mixed access to the same buffers. Returns
58 /// `true` if the first parallel loop writes to the same indices that the second
59 /// loop reads.
60 static bool haveNoReadsAfterWriteExceptSameIndex(
61     ParallelOp firstPloop, ParallelOp secondPloop,
62     const IRMapping &firstToSecondPloopIndices,
63     llvm::function_ref<bool(Value, Value)> mayAlias) {
64   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
65   SmallVector<Value> bufferStoresVec;
66   firstPloop.getBody()->walk([&](memref::StoreOp store) {
67     bufferStores[store.getMemRef()].push_back(store.getIndices());
68     bufferStoresVec.emplace_back(store.getMemRef());
69   });
70   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
71     Value loadMem = load.getMemRef();
72     // Stop if the memref is defined in secondPloop body. Careful alias analysis
73     // is needed.
74     auto *memrefDef = loadMem.getDefiningOp();
75     if (memrefDef && memrefDef->getBlock() == load->getBlock())
76       return WalkResult::interrupt();
77 
78     for (Value store : bufferStoresVec)
79       if (store != loadMem && mayAlias(store, loadMem))
80         return WalkResult::interrupt();
81 
82     auto write = bufferStores.find(loadMem);
83     if (write == bufferStores.end())
84       return WalkResult::advance();
85 
86     // Allow only single write access per buffer.
87     if (write->second.size() != 1)
88       return WalkResult::interrupt();
89 
90     // Check that the load indices of secondPloop coincide with store indices of
91     // firstPloop for the same memrefs.
92     auto storeIndices = write->second.front();
93     auto loadIndices = load.getIndices();
94     if (storeIndices.size() != loadIndices.size())
95       return WalkResult::interrupt();
96     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
97       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
98           loadIndices[i])
99         return WalkResult::interrupt();
100     }
101     return WalkResult::advance();
102   });
103   return !walkResult.wasInterrupted();
104 }
105 
106 /// Analyzes dependencies in the most primitive way by checking simple read and
107 /// write patterns.
108 static LogicalResult
109 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
110                    const IRMapping &firstToSecondPloopIndices,
111                    llvm::function_ref<bool(Value, Value)> mayAlias) {
112   if (!haveNoReadsAfterWriteExceptSameIndex(
113           firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
114     return failure();
115 
116   IRMapping secondToFirstPloopIndices;
117   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
118                                 firstPloop.getBody()->getArguments());
119   return success(haveNoReadsAfterWriteExceptSameIndex(
120       secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
121 }
122 
123 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
124                           const IRMapping &firstToSecondPloopIndices,
125                           llvm::function_ref<bool(Value, Value)> mayAlias) {
126   return !hasNestedParallelOp(firstPloop) &&
127          !hasNestedParallelOp(secondPloop) &&
128          equalIterationSpaces(firstPloop, secondPloop) &&
129          succeeded(verifyDependencies(firstPloop, secondPloop,
130                                       firstToSecondPloopIndices, mayAlias));
131 }
132 
133 /// Prepends operations of firstPloop's body into secondPloop's body.
134 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
135                         OpBuilder b,
136                         llvm::function_ref<bool(Value, Value)> mayAlias) {
137   IRMapping firstToSecondPloopIndices;
138   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
139                                 secondPloop.getBody()->getArguments());
140 
141   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
142                      mayAlias))
143     return;
144 
145   b.setInsertionPointToStart(secondPloop.getBody());
146   for (auto &op : firstPloop.getBody()->without_terminator())
147     b.clone(op, firstToSecondPloopIndices);
148   firstPloop.erase();
149 }
150 
151 void mlir::scf::naivelyFuseParallelOps(
152     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
153   OpBuilder b(region);
154   // Consider every single block and attempt to fuse adjacent loops.
155   for (auto &block : region) {
156     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
157     // Not using `walk()` to traverse only top-level parallel loops and also
158     // make sure that there are no side-effecting ops between the parallel
159     // loops.
160     bool noSideEffects = true;
161     for (auto &op : block) {
162       if (auto ploop = dyn_cast<ParallelOp>(op)) {
163         if (noSideEffects) {
164           ploopChains.back().push_back(ploop);
165         } else {
166           ploopChains.push_back({ploop});
167           noSideEffects = true;
168         }
169         continue;
170       }
171       // TODO: Handle region side effects properly.
172       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
173     }
174     for (ArrayRef<ParallelOp> ploops : ploopChains) {
175       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
176         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
177     }
178   }
179 }
180 
181 namespace {
182 struct ParallelLoopFusion
183     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
184   void runOnOperation() override {
185     auto &AA = getAnalysis<AliasAnalysis>();
186 
187     auto mayAlias = [&](Value val1, Value val2) -> bool {
188       return !AA.alias(val1, val2).isNo();
189     };
190 
191     getOperation()->walk([&](Operation *child) {
192       for (Region &region : child->getRegions())
193         naivelyFuseParallelOps(region, mayAlias);
194     });
195   }
196 };
197 } // namespace
198 
199 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
200   return std::make_unique<ParallelLoopFusion>();
201 }
202