xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision 8aa6c3765b924d86f623d452777eb76b83bf2787)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/OpDefinition.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 /// Verify there are no nested ParallelOps.
27 static bool hasNestedParallelOp(ParallelOp ploop) {
28   auto walkResult =
29       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
30   return walkResult.wasInterrupted();
31 }
32 
33 /// Verify equal iteration spaces.
34 static bool equalIterationSpaces(ParallelOp firstPloop,
35                                  ParallelOp secondPloop) {
36   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
37     return false;
38 
39   auto matchOperands = [&](const OperandRange &lhs,
40                            const OperandRange &rhs) -> bool {
41     // TODO: Extend this to support aliases and equal constants.
42     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
43   };
44   return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
45          matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
46          matchOperands(firstPloop.step(), secondPloop.step());
47 }
48 
49 /// Checks if the parallel loops have mixed access to the same buffers. Returns
50 /// `true` if the first parallel loop writes to the same indices that the second
51 /// loop reads.
52 static bool haveNoReadsAfterWriteExceptSameIndex(
53     ParallelOp firstPloop, ParallelOp secondPloop,
54     const BlockAndValueMapping &firstToSecondPloopIndices) {
55   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
56   firstPloop.getBody()->walk([&](memref::StoreOp store) {
57     bufferStores[store.getMemRef()].push_back(store.indices());
58   });
59   auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) {
60     // Stop if the memref is defined in secondPloop body. Careful alias analysis
61     // is needed.
62     auto *memrefDef = load.getMemRef().getDefiningOp();
63     if (memrefDef && memrefDef->getBlock() == load->getBlock())
64       return WalkResult::interrupt();
65 
66     auto write = bufferStores.find(load.getMemRef());
67     if (write == bufferStores.end())
68       return WalkResult::advance();
69 
70     // Allow only single write access per buffer.
71     if (write->second.size() != 1)
72       return WalkResult::interrupt();
73 
74     // Check that the load indices of secondPloop coincide with store indices of
75     // firstPloop for the same memrefs.
76     auto storeIndices = write->second.front();
77     auto loadIndices = load.indices();
78     if (storeIndices.size() != loadIndices.size())
79       return WalkResult::interrupt();
80     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
81       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
82           loadIndices[i])
83         return WalkResult::interrupt();
84     }
85     return WalkResult::advance();
86   });
87   return !walkResult.wasInterrupted();
88 }
89 
90 /// Analyzes dependencies in the most primitive way by checking simple read and
91 /// write patterns.
92 static LogicalResult
93 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
94                    const BlockAndValueMapping &firstToSecondPloopIndices) {
95   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
96                                             firstToSecondPloopIndices))
97     return failure();
98 
99   BlockAndValueMapping secondToFirstPloopIndices;
100   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
101                                 firstPloop.getBody()->getArguments());
102   return success(haveNoReadsAfterWriteExceptSameIndex(
103       secondPloop, firstPloop, secondToFirstPloopIndices));
104 }
105 
106 static bool
107 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
108               const BlockAndValueMapping &firstToSecondPloopIndices) {
109   return !hasNestedParallelOp(firstPloop) &&
110          !hasNestedParallelOp(secondPloop) &&
111          equalIterationSpaces(firstPloop, secondPloop) &&
112          succeeded(verifyDependencies(firstPloop, secondPloop,
113                                       firstToSecondPloopIndices));
114 }
115 
116 /// Prepends operations of firstPloop's body into secondPloop's body.
117 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
118                         OpBuilder b) {
119   BlockAndValueMapping firstToSecondPloopIndices;
120   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
121                                 secondPloop.getBody()->getArguments());
122 
123   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
124     return;
125 
126   b.setInsertionPointToStart(secondPloop.getBody());
127   for (auto &op : firstPloop.getBody()->without_terminator())
128     b.clone(op, firstToSecondPloopIndices);
129   firstPloop.erase();
130 }
131 
132 void mlir::scf::naivelyFuseParallelOps(Region &region) {
133   OpBuilder b(region);
134   // Consider every single block and attempt to fuse adjacent loops.
135   for (auto &block : region) {
136     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
137     // Not using `walk()` to traverse only top-level parallel loops and also
138     // make sure that there are no side-effecting ops between the parallel
139     // loops.
140     bool noSideEffects = true;
141     for (auto &op : block) {
142       if (auto ploop = dyn_cast<ParallelOp>(op)) {
143         if (noSideEffects) {
144           ploopChains.back().push_back(ploop);
145         } else {
146           ploopChains.push_back({ploop});
147           noSideEffects = true;
148         }
149         continue;
150       }
151       // TODO: Handle region side effects properly.
152       noSideEffects &=
153           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
154     }
155     for (ArrayRef<ParallelOp> ploops : ploopChains) {
156       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
157         fuseIfLegal(ploops[i], ploops[i + 1], b);
158     }
159   }
160 }
161 
162 namespace {
163 struct ParallelLoopFusion
164     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
165   void runOnOperation() override {
166     getOperation()->walk([&](Operation *child) {
167       for (Region &region : child->getRegions())
168         naivelyFuseParallelOps(region);
169     });
170   }
171 };
172 } // namespace
173 
174 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
175   return std::make_unique<ParallelLoopFusion>();
176 }
177