xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision c0342a2de8aa9ffe129bce50e9c90647a772bb2b)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/OpDefinition.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 /// Verify there are no nested ParallelOps.
27 static bool hasNestedParallelOp(ParallelOp ploop) {
28   auto walkResult =
29       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
30   return walkResult.wasInterrupted();
31 }
32 
33 /// Verify equal iteration spaces.
34 static bool equalIterationSpaces(ParallelOp firstPloop,
35                                  ParallelOp secondPloop) {
36   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
37     return false;
38 
39   auto matchOperands = [&](const OperandRange &lhs,
40                            const OperandRange &rhs) -> bool {
41     // TODO: Extend this to support aliases and equal constants.
42     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
43   };
44   return matchOperands(firstPloop.getLowerBound(),
45                        secondPloop.getLowerBound()) &&
46          matchOperands(firstPloop.getUpperBound(),
47                        secondPloop.getUpperBound()) &&
48          matchOperands(firstPloop.getStep(), secondPloop.getStep());
49 }
50 
51 /// Checks if the parallel loops have mixed access to the same buffers. Returns
52 /// `true` if the first parallel loop writes to the same indices that the second
53 /// loop reads.
54 static bool haveNoReadsAfterWriteExceptSameIndex(
55     ParallelOp firstPloop, ParallelOp secondPloop,
56     const BlockAndValueMapping &firstToSecondPloopIndices) {
57   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
58   firstPloop.getBody()->walk([&](memref::StoreOp store) {
59     bufferStores[store.getMemRef()].push_back(store.indices());
60   });
61   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
62     // Stop if the memref is defined in secondPloop body. Careful alias analysis
63     // is needed.
64     auto *memrefDef = load.getMemRef().getDefiningOp();
65     if (memrefDef && memrefDef->getBlock() == load->getBlock())
66       return WalkResult::interrupt();
67 
68     auto write = bufferStores.find(load.getMemRef());
69     if (write == bufferStores.end())
70       return WalkResult::advance();
71 
72     // Allow only single write access per buffer.
73     if (write->second.size() != 1)
74       return WalkResult::interrupt();
75 
76     // Check that the load indices of secondPloop coincide with store indices of
77     // firstPloop for the same memrefs.
78     auto storeIndices = write->second.front();
79     auto loadIndices = load.indices();
80     if (storeIndices.size() != loadIndices.size())
81       return WalkResult::interrupt();
82     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
83       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
84           loadIndices[i])
85         return WalkResult::interrupt();
86     }
87     return WalkResult::advance();
88   });
89   return !walkResult.wasInterrupted();
90 }
91 
92 /// Analyzes dependencies in the most primitive way by checking simple read and
93 /// write patterns.
94 static LogicalResult
95 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
96                    const BlockAndValueMapping &firstToSecondPloopIndices) {
97   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
98                                             firstToSecondPloopIndices))
99     return failure();
100 
101   BlockAndValueMapping secondToFirstPloopIndices;
102   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
103                                 firstPloop.getBody()->getArguments());
104   return success(haveNoReadsAfterWriteExceptSameIndex(
105       secondPloop, firstPloop, secondToFirstPloopIndices));
106 }
107 
108 static bool
109 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
110               const BlockAndValueMapping &firstToSecondPloopIndices) {
111   return !hasNestedParallelOp(firstPloop) &&
112          !hasNestedParallelOp(secondPloop) &&
113          equalIterationSpaces(firstPloop, secondPloop) &&
114          succeeded(verifyDependencies(firstPloop, secondPloop,
115                                       firstToSecondPloopIndices));
116 }
117 
118 /// Prepends operations of firstPloop's body into secondPloop's body.
119 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
120                         OpBuilder b) {
121   BlockAndValueMapping firstToSecondPloopIndices;
122   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
123                                 secondPloop.getBody()->getArguments());
124 
125   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
126     return;
127 
128   b.setInsertionPointToStart(secondPloop.getBody());
129   for (auto &op : firstPloop.getBody()->without_terminator())
130     b.clone(op, firstToSecondPloopIndices);
131   firstPloop.erase();
132 }
133 
134 void mlir::scf::naivelyFuseParallelOps(Region &region) {
135   OpBuilder b(region);
136   // Consider every single block and attempt to fuse adjacent loops.
137   for (auto &block : region) {
138     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
139     // Not using `walk()` to traverse only top-level parallel loops and also
140     // make sure that there are no side-effecting ops between the parallel
141     // loops.
142     bool noSideEffects = true;
143     for (auto &op : block) {
144       if (auto ploop = dyn_cast<ParallelOp>(op)) {
145         if (noSideEffects) {
146           ploopChains.back().push_back(ploop);
147         } else {
148           ploopChains.push_back({ploop});
149           noSideEffects = true;
150         }
151         continue;
152       }
153       // TODO: Handle region side effects properly.
154       noSideEffects &=
155           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
156     }
157     for (ArrayRef<ParallelOp> ploops : ploopChains) {
158       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
159         fuseIfLegal(ploops[i], ploops[i + 1], b);
160     }
161   }
162 }
163 
164 namespace {
165 struct ParallelLoopFusion
166     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
167   void runOnOperation() override {
168     getOperation()->walk([&](Operation *child) {
169       for (Region &region : child->getRegions())
170         naivelyFuseParallelOps(region);
171     });
172   }
173 };
174 } // namespace
175 
176 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
177   return std::make_unique<ParallelLoopFusion>();
178 }
179