xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Passes.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21 
22 using namespace mlir;
23 using namespace mlir::scf;
24 
25 /// Verify there are no nested ParallelOps.
26 static bool hasNestedParallelOp(ParallelOp ploop) {
27   auto walkResult =
28       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29   return walkResult.wasInterrupted();
30 }
31 
32 /// Verify equal iteration spaces.
33 static bool equalIterationSpaces(ParallelOp firstPloop,
34                                  ParallelOp secondPloop) {
35   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36     return false;
37 
38   auto matchOperands = [&](const OperandRange &lhs,
39                            const OperandRange &rhs) -> bool {
40     // TODO: Extend this to support aliases and equal constants.
41     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42   };
43   return matchOperands(firstPloop.getLowerBound(),
44                        secondPloop.getLowerBound()) &&
45          matchOperands(firstPloop.getUpperBound(),
46                        secondPloop.getUpperBound()) &&
47          matchOperands(firstPloop.getStep(), secondPloop.getStep());
48 }
49 
50 /// Checks if the parallel loops have mixed access to the same buffers. Returns
51 /// `true` if the first parallel loop writes to the same indices that the second
52 /// loop reads.
53 static bool haveNoReadsAfterWriteExceptSameIndex(
54     ParallelOp firstPloop, ParallelOp secondPloop,
55     const BlockAndValueMapping &firstToSecondPloopIndices) {
56   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
57   firstPloop.getBody()->walk([&](memref::StoreOp store) {
58     bufferStores[store.getMemRef()].push_back(store.getIndices());
59   });
60   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
61     // Stop if the memref is defined in secondPloop body. Careful alias analysis
62     // is needed.
63     auto *memrefDef = load.getMemRef().getDefiningOp();
64     if (memrefDef && memrefDef->getBlock() == load->getBlock())
65       return WalkResult::interrupt();
66 
67     auto write = bufferStores.find(load.getMemRef());
68     if (write == bufferStores.end())
69       return WalkResult::advance();
70 
71     // Allow only single write access per buffer.
72     if (write->second.size() != 1)
73       return WalkResult::interrupt();
74 
75     // Check that the load indices of secondPloop coincide with store indices of
76     // firstPloop for the same memrefs.
77     auto storeIndices = write->second.front();
78     auto loadIndices = load.getIndices();
79     if (storeIndices.size() != loadIndices.size())
80       return WalkResult::interrupt();
81     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
82       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
83           loadIndices[i])
84         return WalkResult::interrupt();
85     }
86     return WalkResult::advance();
87   });
88   return !walkResult.wasInterrupted();
89 }
90 
91 /// Analyzes dependencies in the most primitive way by checking simple read and
92 /// write patterns.
93 static LogicalResult
94 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
95                    const BlockAndValueMapping &firstToSecondPloopIndices) {
96   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
97                                             firstToSecondPloopIndices))
98     return failure();
99 
100   BlockAndValueMapping secondToFirstPloopIndices;
101   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
102                                 firstPloop.getBody()->getArguments());
103   return success(haveNoReadsAfterWriteExceptSameIndex(
104       secondPloop, firstPloop, secondToFirstPloopIndices));
105 }
106 
107 static bool
108 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
109               const BlockAndValueMapping &firstToSecondPloopIndices) {
110   return !hasNestedParallelOp(firstPloop) &&
111          !hasNestedParallelOp(secondPloop) &&
112          equalIterationSpaces(firstPloop, secondPloop) &&
113          succeeded(verifyDependencies(firstPloop, secondPloop,
114                                       firstToSecondPloopIndices));
115 }
116 
117 /// Prepends operations of firstPloop's body into secondPloop's body.
118 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
119                         OpBuilder b) {
120   BlockAndValueMapping firstToSecondPloopIndices;
121   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
122                                 secondPloop.getBody()->getArguments());
123 
124   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
125     return;
126 
127   b.setInsertionPointToStart(secondPloop.getBody());
128   for (auto &op : firstPloop.getBody()->without_terminator())
129     b.clone(op, firstToSecondPloopIndices);
130   firstPloop.erase();
131 }
132 
133 void mlir::scf::naivelyFuseParallelOps(Region &region) {
134   OpBuilder b(region);
135   // Consider every single block and attempt to fuse adjacent loops.
136   for (auto &block : region) {
137     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
138     // Not using `walk()` to traverse only top-level parallel loops and also
139     // make sure that there are no side-effecting ops between the parallel
140     // loops.
141     bool noSideEffects = true;
142     for (auto &op : block) {
143       if (auto ploop = dyn_cast<ParallelOp>(op)) {
144         if (noSideEffects) {
145           ploopChains.back().push_back(ploop);
146         } else {
147           ploopChains.push_back({ploop});
148           noSideEffects = true;
149         }
150         continue;
151       }
152       // TODO: Handle region side effects properly.
153       noSideEffects &=
154           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
155     }
156     for (ArrayRef<ParallelOp> ploops : ploopChains) {
157       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
158         fuseIfLegal(ploops[i], ploops[i + 1], b);
159     }
160   }
161 }
162 
163 namespace {
164 struct ParallelLoopFusion
165     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
166   void runOnOperation() override {
167     getOperation()->walk([&](Operation *child) {
168       for (Region &region : child->getRegions())
169         naivelyFuseParallelOps(region);
170     });
171   }
172 };
173 } // namespace
174 
175 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
176   return std::make_unique<ParallelLoopFusion>();
177 }
178