xref: /llvm-project/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Chained elemental operations like a + b + c can inline the first elemental
9 // at the hlfir.apply in the body of the second one (as described in
10 // docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering
11 // so that it happens after the HLFIR intrinsic simplification pass.
12 //===----------------------------------------------------------------------===//
13 
14 #include "flang/Optimizer/Builder/FIRBuilder.h"
15 #include "flang/Optimizer/Builder/HLFIRTools.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 #include "flang/Optimizer/HLFIR/Passes.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include <iterator>
28 
29 namespace hlfir {
30 #define GEN_PASS_DEF_INLINEELEMENTALS
31 #include "flang/Optimizer/HLFIR/Passes.h.inc"
32 } // namespace hlfir
33 
34 /// If the elemental has only two uses and those two are an apply operation and
35 /// a destroy operation, return those two, otherwise return {}
36 static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
37 getTwoUses(hlfir::ElementalOp elemental) {
38   mlir::Operation::user_range users = elemental->getUsers();
39   // don't inline anything with more than one use (plus hfir.destroy)
40   if (std::distance(users.begin(), users.end()) != 2) {
41     return std::nullopt;
42   }
43 
44   // If the ElementalOp must produce a temporary (e.g. for
45   // finalization purposes), then we cannot inline it.
46   if (hlfir::elementalOpMustProduceTemp(elemental))
47     return std::nullopt;
48 
49   hlfir::ApplyOp apply;
50   hlfir::DestroyOp destroy;
51   for (mlir::Operation *user : users)
52     mlir::TypeSwitch<mlir::Operation *, void>(user)
53         .Case([&](hlfir::ApplyOp op) { apply = op; })
54         .Case([&](hlfir::DestroyOp op) { destroy = op; });
55 
56   if (!apply || !destroy)
57     return std::nullopt;
58 
59   // we can't inline if the return type of the yield doesn't match the return
60   // type of the apply
61   auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>(
62       elemental.getRegion().back().back());
63   assert(yield && "hlfir.elemental should always end with a yield");
64   if (apply.getResult().getType() != yield.getElementValue().getType())
65     return std::nullopt;
66 
67   return std::pair{apply, destroy};
68 }
69 
70 namespace {
71 class InlineElementalConversion
72     : public mlir::OpRewritePattern<hlfir::ElementalOp> {
73 public:
74   using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
75 
76   llvm::LogicalResult
77   matchAndRewrite(hlfir::ElementalOp elemental,
78                   mlir::PatternRewriter &rewriter) const override {
79     std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple =
80         getTwoUses(elemental);
81     if (!maybeTuple)
82       return rewriter.notifyMatchFailure(
83           elemental, "hlfir.elemental does not have two uses");
84 
85     if (elemental.isOrdered()) {
86       // We can only inline the ordered elemental into a loop-like
87       // construct that processes the indices in-order and does not
88       // have the side effects itself. Adhere to conservative behavior
89       // for the time being.
90       return rewriter.notifyMatchFailure(elemental,
91                                          "hlfir.elemental is ordered");
92     }
93     auto [apply, destroy] = *maybeTuple;
94 
95     assert(elemental.getRegion().hasOneBlock() &&
96            "expect elemental region to have one block");
97 
98     fir::FirOpBuilder builder{rewriter, elemental.getOperation()};
99     builder.setInsertionPointAfter(apply);
100     hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
101         elemental.getLoc(), builder, elemental, apply.getIndices());
102 
103     // remove the old elemental and all of the bookkeeping
104     rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue());
105     rewriter.eraseOp(yield);
106     rewriter.eraseOp(apply);
107     rewriter.eraseOp(destroy);
108     rewriter.eraseOp(elemental);
109 
110     return mlir::success();
111   }
112 };
113 
114 class InlineElementalsPass
115     : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> {
116 public:
117   void runOnOperation() override {
118     mlir::MLIRContext *context = &getContext();
119 
120     mlir::GreedyRewriteConfig config;
121     // Prevent the pattern driver from merging blocks.
122     config.enableRegionSimplification =
123         mlir::GreedySimplifyRegionLevel::Disabled;
124 
125     mlir::RewritePatternSet patterns(context);
126     patterns.insert<InlineElementalConversion>(context);
127 
128     if (mlir::failed(mlir::applyPatternsGreedily(
129             getOperation(), std::move(patterns), config))) {
130       mlir::emitError(getOperation()->getLoc(),
131                       "failure in HLFIR elemental inlining");
132       signalPassFailure();
133     }
134   }
135 };
136 } // namespace
137