xref: /llvm-project/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp (revision 0db88db5d90f8d8a76360725597e0dfb82cc9661)
1 //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Dialect/FIRDialect.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Dialect/FIRType.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Transforms/Passes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <optional>
20 
21 namespace fir {
22 #define GEN_PASS_DEF_MEMREFDATAFLOWOPT
23 #include "flang/Optimizer/Transforms/Passes.h.inc"
24 } // namespace fir
25 
26 #define DEBUG_TYPE "fir-memref-dataflow-opt"
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 template <typename OpT>
getSpecificUsers(mlir::Value v)33 static std::vector<OpT> getSpecificUsers(mlir::Value v) {
34   std::vector<OpT> ops;
35   for (mlir::Operation *user : v.getUsers())
36     if (auto op = dyn_cast<OpT>(user))
37       ops.push_back(op);
38   return ops;
39 }
40 
41 /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
42 /// and AffineWrite interface
43 template <typename ReadOp, typename WriteOp>
44 class LoadStoreForwarding {
45 public:
LoadStoreForwarding(mlir::DominanceInfo * di)46   LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
47 
48   // FIXME: This algorithm has a bug. It ignores escaping references between a
49   // store and a load.
findStoreToForward(ReadOp loadOp,std::vector<WriteOp> && storeOps)50   std::optional<WriteOp> findStoreToForward(ReadOp loadOp,
51                                             std::vector<WriteOp> &&storeOps) {
52     llvm::SmallVector<WriteOp> candidateSet;
53 
54     for (auto storeOp : storeOps)
55       if (domInfo->dominates(storeOp, loadOp))
56         candidateSet.push_back(storeOp);
57 
58     if (candidateSet.empty())
59       return {};
60 
61     std::optional<WriteOp> nearestStore;
62     for (auto candidate : candidateSet) {
63       auto nearerThan = [&](WriteOp otherStore) {
64         if (candidate == otherStore)
65           return false;
66         bool rv = domInfo->properlyDominates(candidate, otherStore);
67         if (rv) {
68           LLVM_DEBUG(llvm::dbgs()
69                      << "candidate " << candidate << " is not the nearest to "
70                      << loadOp << " because " << otherStore << " is closer\n");
71         }
72         return rv;
73       };
74       if (!llvm::any_of(candidateSet, nearerThan)) {
75         nearestStore = mlir::cast<WriteOp>(candidate);
76         break;
77       }
78     }
79     if (!nearestStore) {
80       LLVM_DEBUG(
81           llvm::dbgs()
82           << "load " << loadOp << " has " << candidateSet.size()
83           << " store candidates, but this algorithm can't find a best.\n");
84     }
85     return nearestStore;
86   }
87 
findReadForWrite(WriteOp storeOp,std::vector<ReadOp> && loadOps)88   std::optional<ReadOp> findReadForWrite(WriteOp storeOp,
89                                          std::vector<ReadOp> &&loadOps) {
90     for (auto &loadOp : loadOps) {
91       if (domInfo->dominates(storeOp, loadOp))
92         return loadOp;
93     }
94     return {};
95   }
96 
97 private:
98   mlir::DominanceInfo *domInfo;
99 };
100 
101 class MemDataFlowOpt : public fir::impl::MemRefDataFlowOptBase<MemDataFlowOpt> {
102 public:
runOnOperation()103   void runOnOperation() override {
104     mlir::func::FuncOp f = getOperation();
105 
106     auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
107     LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
108     f.walk([&](fir::LoadOp loadOp) {
109       auto maybeStore = lsf.findStoreToForward(
110           loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref()));
111       if (maybeStore) {
112         auto storeOp = *maybeStore;
113         LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
114                                 << " erasing load " << loadOp
115                                 << " with value from " << storeOp << '\n');
116         loadOp.getResult().replaceAllUsesWith(storeOp.getValue());
117         loadOp.erase();
118       }
119     });
120     f.walk([&](fir::AllocaOp alloca) {
121       for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
122         if (!lsf.findReadForWrite(
123                 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) {
124           LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
125                                   << " erasing store " << storeOp << '\n');
126           storeOp.erase();
127         }
128       }
129     });
130   }
131 };
132 } // namespace
133 
createMemDataFlowOptPass()134 std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
135   return std::make_unique<MemDataFlowOpt>();
136 }
137