xref: /llvm-project/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp (revision 0db88db5d90f8d8a76360725597e0dfb82cc9661)
12e7202b0SValentin Clement //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
22e7202b0SValentin Clement //
32e7202b0SValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42e7202b0SValentin Clement // See https://llvm.org/LICENSE.txt for license information.
52e7202b0SValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62e7202b0SValentin Clement //
72e7202b0SValentin Clement //===----------------------------------------------------------------------===//
82e7202b0SValentin Clement 
92e7202b0SValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h"
102e7202b0SValentin Clement #include "flang/Optimizer/Dialect/FIROps.h"
112e7202b0SValentin Clement #include "flang/Optimizer/Dialect/FIRType.h"
122e7202b0SValentin Clement #include "flang/Optimizer/Transforms/Passes.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
142e7202b0SValentin Clement #include "mlir/IR/Dominance.h"
152e7202b0SValentin Clement #include "mlir/IR/Operation.h"
162e7202b0SValentin Clement #include "mlir/Transforms/Passes.h"
172e7202b0SValentin Clement #include "llvm/ADT/STLExtras.h"
182e7202b0SValentin Clement #include "llvm/ADT/SmallVector.h"
194d4d4785SKazu Hirata #include <optional>
202e7202b0SValentin Clement 
2167d0d7acSMichele Scuttari namespace fir {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_MEMREFDATAFLOWOPT
2367d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace fir
2567d0d7acSMichele Scuttari 
262e7202b0SValentin Clement #define DEBUG_TYPE "fir-memref-dataflow-opt"
272e7202b0SValentin Clement 
28092601d4SAndrzej Warzynski using namespace mlir;
29092601d4SAndrzej Warzynski 
302e7202b0SValentin Clement namespace {
312e7202b0SValentin Clement 
322e7202b0SValentin Clement template <typename OpT>
getSpecificUsers(mlir::Value v)332e7202b0SValentin Clement static std::vector<OpT> getSpecificUsers(mlir::Value v) {
342e7202b0SValentin Clement   std::vector<OpT> ops;
352e7202b0SValentin Clement   for (mlir::Operation *user : v.getUsers())
362e7202b0SValentin Clement     if (auto op = dyn_cast<OpT>(user))
372e7202b0SValentin Clement       ops.push_back(op);
382e7202b0SValentin Clement   return ops;
392e7202b0SValentin Clement }
402e7202b0SValentin Clement 
412e7202b0SValentin Clement /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
422e7202b0SValentin Clement /// and AffineWrite interface
432e7202b0SValentin Clement template <typename ReadOp, typename WriteOp>
442e7202b0SValentin Clement class LoadStoreForwarding {
452e7202b0SValentin Clement public:
LoadStoreForwarding(mlir::DominanceInfo * di)462e7202b0SValentin Clement   LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
472e7202b0SValentin Clement 
482e7202b0SValentin Clement   // FIXME: This algorithm has a bug. It ignores escaping references between a
492e7202b0SValentin Clement   // store and a load.
findStoreToForward(ReadOp loadOp,std::vector<WriteOp> && storeOps)50*c0921586SKazu Hirata   std::optional<WriteOp> findStoreToForward(ReadOp loadOp,
512e7202b0SValentin Clement                                             std::vector<WriteOp> &&storeOps) {
522e7202b0SValentin Clement     llvm::SmallVector<WriteOp> candidateSet;
532e7202b0SValentin Clement 
542e7202b0SValentin Clement     for (auto storeOp : storeOps)
552e7202b0SValentin Clement       if (domInfo->dominates(storeOp, loadOp))
562e7202b0SValentin Clement         candidateSet.push_back(storeOp);
572e7202b0SValentin Clement 
582e7202b0SValentin Clement     if (candidateSet.empty())
592e7202b0SValentin Clement       return {};
602e7202b0SValentin Clement 
61*c0921586SKazu Hirata     std::optional<WriteOp> nearestStore;
622e7202b0SValentin Clement     for (auto candidate : candidateSet) {
632e7202b0SValentin Clement       auto nearerThan = [&](WriteOp otherStore) {
642e7202b0SValentin Clement         if (candidate == otherStore)
652e7202b0SValentin Clement           return false;
662e7202b0SValentin Clement         bool rv = domInfo->properlyDominates(candidate, otherStore);
672e7202b0SValentin Clement         if (rv) {
682e7202b0SValentin Clement           LLVM_DEBUG(llvm::dbgs()
692e7202b0SValentin Clement                      << "candidate " << candidate << " is not the nearest to "
702e7202b0SValentin Clement                      << loadOp << " because " << otherStore << " is closer\n");
712e7202b0SValentin Clement         }
722e7202b0SValentin Clement         return rv;
732e7202b0SValentin Clement       };
742e7202b0SValentin Clement       if (!llvm::any_of(candidateSet, nearerThan)) {
752e7202b0SValentin Clement         nearestStore = mlir::cast<WriteOp>(candidate);
762e7202b0SValentin Clement         break;
772e7202b0SValentin Clement       }
782e7202b0SValentin Clement     }
792e7202b0SValentin Clement     if (!nearestStore) {
802e7202b0SValentin Clement       LLVM_DEBUG(
812e7202b0SValentin Clement           llvm::dbgs()
822e7202b0SValentin Clement           << "load " << loadOp << " has " << candidateSet.size()
832e7202b0SValentin Clement           << " store candidates, but this algorithm can't find a best.\n");
842e7202b0SValentin Clement     }
852e7202b0SValentin Clement     return nearestStore;
862e7202b0SValentin Clement   }
872e7202b0SValentin Clement 
findReadForWrite(WriteOp storeOp,std::vector<ReadOp> && loadOps)88*c0921586SKazu Hirata   std::optional<ReadOp> findReadForWrite(WriteOp storeOp,
892e7202b0SValentin Clement                                          std::vector<ReadOp> &&loadOps) {
902e7202b0SValentin Clement     for (auto &loadOp : loadOps) {
912e7202b0SValentin Clement       if (domInfo->dominates(storeOp, loadOp))
922e7202b0SValentin Clement         return loadOp;
932e7202b0SValentin Clement     }
942e7202b0SValentin Clement     return {};
952e7202b0SValentin Clement   }
962e7202b0SValentin Clement 
972e7202b0SValentin Clement private:
982e7202b0SValentin Clement   mlir::DominanceInfo *domInfo;
992e7202b0SValentin Clement };
1002e7202b0SValentin Clement 
10167d0d7acSMichele Scuttari class MemDataFlowOpt : public fir::impl::MemRefDataFlowOptBase<MemDataFlowOpt> {
1022e7202b0SValentin Clement public:
runOnOperation()103196c4279SRiver Riddle   void runOnOperation() override {
10458ceae95SRiver Riddle     mlir::func::FuncOp f = getOperation();
1052e7202b0SValentin Clement 
1062e7202b0SValentin Clement     auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
1072e7202b0SValentin Clement     LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
1082e7202b0SValentin Clement     f.walk([&](fir::LoadOp loadOp) {
1092e7202b0SValentin Clement       auto maybeStore = lsf.findStoreToForward(
110149ad3d5SShraiysh Vaishay           loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref()));
1112e7202b0SValentin Clement       if (maybeStore) {
112ed8fceaaSKazu Hirata         auto storeOp = *maybeStore;
1132e7202b0SValentin Clement         LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
1142e7202b0SValentin Clement                                 << " erasing load " << loadOp
1152e7202b0SValentin Clement                                 << " with value from " << storeOp << '\n');
116149ad3d5SShraiysh Vaishay         loadOp.getResult().replaceAllUsesWith(storeOp.getValue());
1172e7202b0SValentin Clement         loadOp.erase();
1182e7202b0SValentin Clement       }
1192e7202b0SValentin Clement     });
1202e7202b0SValentin Clement     f.walk([&](fir::AllocaOp alloca) {
1212e7202b0SValentin Clement       for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
1222e7202b0SValentin Clement         if (!lsf.findReadForWrite(
123149ad3d5SShraiysh Vaishay                 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) {
1242e7202b0SValentin Clement           LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
1252e7202b0SValentin Clement                                   << " erasing store " << storeOp << '\n');
1262e7202b0SValentin Clement           storeOp.erase();
1272e7202b0SValentin Clement         }
1282e7202b0SValentin Clement       }
1292e7202b0SValentin Clement     });
1302e7202b0SValentin Clement   }
1312e7202b0SValentin Clement };
1322e7202b0SValentin Clement } // namespace
1332e7202b0SValentin Clement 
createMemDataFlowOptPass()1342e7202b0SValentin Clement std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
1352e7202b0SValentin Clement   return std::make_unique<MemDataFlowOpt>();
1362e7202b0SValentin Clement }
137