xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision fc367dfa6770fda4adc4e5de79846f22e7e4e215)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
25 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::scf;
30 
31 /// Verify there are no nested ParallelOps.
32 static bool hasNestedParallelOp(ParallelOp ploop) {
33   auto walkResult =
34       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
35   return walkResult.wasInterrupted();
36 }
37 
38 /// Verify equal iteration spaces.
39 static bool equalIterationSpaces(ParallelOp firstPloop,
40                                  ParallelOp secondPloop) {
41   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
42     return false;
43 
44   auto matchOperands = [&](const OperandRange &lhs,
45                            const OperandRange &rhs) -> bool {
46     // TODO: Extend this to support aliases and equal constants.
47     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
48   };
49   return matchOperands(firstPloop.getLowerBound(),
50                        secondPloop.getLowerBound()) &&
51          matchOperands(firstPloop.getUpperBound(),
52                        secondPloop.getUpperBound()) &&
53          matchOperands(firstPloop.getStep(), secondPloop.getStep());
54 }
55 
56 /// Checks if the parallel loops have mixed access to the same buffers. Returns
57 /// `true` if the first parallel loop writes to the same indices that the second
58 /// loop reads.
59 static bool haveNoReadsAfterWriteExceptSameIndex(
60     ParallelOp firstPloop, ParallelOp secondPloop,
61     const BlockAndValueMapping &firstToSecondPloopIndices) {
62   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
63   firstPloop.getBody()->walk([&](memref::StoreOp store) {
64     bufferStores[store.getMemRef()].push_back(store.getIndices());
65   });
66   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
67     // Stop if the memref is defined in secondPloop body. Careful alias analysis
68     // is needed.
69     auto *memrefDef = load.getMemRef().getDefiningOp();
70     if (memrefDef && memrefDef->getBlock() == load->getBlock())
71       return WalkResult::interrupt();
72 
73     auto write = bufferStores.find(load.getMemRef());
74     if (write == bufferStores.end())
75       return WalkResult::advance();
76 
77     // Allow only single write access per buffer.
78     if (write->second.size() != 1)
79       return WalkResult::interrupt();
80 
81     // Check that the load indices of secondPloop coincide with store indices of
82     // firstPloop for the same memrefs.
83     auto storeIndices = write->second.front();
84     auto loadIndices = load.getIndices();
85     if (storeIndices.size() != loadIndices.size())
86       return WalkResult::interrupt();
87     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
88       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
89           loadIndices[i])
90         return WalkResult::interrupt();
91     }
92     return WalkResult::advance();
93   });
94   return !walkResult.wasInterrupted();
95 }
96 
97 /// Analyzes dependencies in the most primitive way by checking simple read and
98 /// write patterns.
99 static LogicalResult
100 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
101                    const BlockAndValueMapping &firstToSecondPloopIndices) {
102   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
103                                             firstToSecondPloopIndices))
104     return failure();
105 
106   BlockAndValueMapping secondToFirstPloopIndices;
107   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
108                                 firstPloop.getBody()->getArguments());
109   return success(haveNoReadsAfterWriteExceptSameIndex(
110       secondPloop, firstPloop, secondToFirstPloopIndices));
111 }
112 
113 static bool
114 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
115               const BlockAndValueMapping &firstToSecondPloopIndices) {
116   return !hasNestedParallelOp(firstPloop) &&
117          !hasNestedParallelOp(secondPloop) &&
118          equalIterationSpaces(firstPloop, secondPloop) &&
119          succeeded(verifyDependencies(firstPloop, secondPloop,
120                                       firstToSecondPloopIndices));
121 }
122 
123 /// Prepends operations of firstPloop's body into secondPloop's body.
124 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
125                         OpBuilder b) {
126   BlockAndValueMapping firstToSecondPloopIndices;
127   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
128                                 secondPloop.getBody()->getArguments());
129 
130   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
131     return;
132 
133   b.setInsertionPointToStart(secondPloop.getBody());
134   for (auto &op : firstPloop.getBody()->without_terminator())
135     b.clone(op, firstToSecondPloopIndices);
136   firstPloop.erase();
137 }
138 
139 void mlir::scf::naivelyFuseParallelOps(Region &region) {
140   OpBuilder b(region);
141   // Consider every single block and attempt to fuse adjacent loops.
142   for (auto &block : region) {
143     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
144     // Not using `walk()` to traverse only top-level parallel loops and also
145     // make sure that there are no side-effecting ops between the parallel
146     // loops.
147     bool noSideEffects = true;
148     for (auto &op : block) {
149       if (auto ploop = dyn_cast<ParallelOp>(op)) {
150         if (noSideEffects) {
151           ploopChains.back().push_back(ploop);
152         } else {
153           ploopChains.push_back({ploop});
154           noSideEffects = true;
155         }
156         continue;
157       }
158       // TODO: Handle region side effects properly.
159       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
160     }
161     for (ArrayRef<ParallelOp> ploops : ploopChains) {
162       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
163         fuseIfLegal(ploops[i], ploops[i + 1], b);
164     }
165   }
166 }
167 
168 namespace {
169 struct ParallelLoopFusion
170     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
171   void runOnOperation() override {
172     getOperation()->walk([&](Operation *child) {
173       for (Region &region : child->getRegions())
174         naivelyFuseParallelOps(region);
175     });
176   }
177 };
178 } // namespace
179 
180 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
181   return std::make_unique<ParallelLoopFusion>();
182 }
183