1 //===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // Chained elemental operations like a + b + c can inline the first elemental 9 // at the hlfir.apply in the body of the second one (as described in 10 // docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering 11 // so that it happens after the HLFIR intrinsic simplification pass. 12 //===----------------------------------------------------------------------===// 13 14 #include "flang/Optimizer/Builder/FIRBuilder.h" 15 #include "flang/Optimizer/Builder/HLFIRTools.h" 16 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 17 #include "flang/Optimizer/HLFIR/HLFIROps.h" 18 #include "flang/Optimizer/HLFIR/Passes.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include <iterator> 28 29 namespace hlfir { 30 #define GEN_PASS_DEF_INLINEELEMENTALS 31 #include "flang/Optimizer/HLFIR/Passes.h.inc" 32 } // namespace hlfir 33 34 /// If the elemental has only two uses and those two are an apply operation and 35 /// a destroy operation, return those two, otherwise return {} 36 static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> 37 getTwoUses(hlfir::ElementalOp elemental) { 38 mlir::Operation::user_range users = elemental->getUsers(); 39 // don't inline anything with more than one use (plus hfir.destroy) 40 if (std::distance(users.begin(), users.end()) != 2) { 41 return std::nullopt; 42 } 43 44 // If the ElementalOp must produce a temporary (e.g. for 45 // finalization purposes), then we cannot inline it. 46 if (hlfir::elementalOpMustProduceTemp(elemental)) 47 return std::nullopt; 48 49 hlfir::ApplyOp apply; 50 hlfir::DestroyOp destroy; 51 for (mlir::Operation *user : users) 52 mlir::TypeSwitch<mlir::Operation *, void>(user) 53 .Case([&](hlfir::ApplyOp op) { apply = op; }) 54 .Case([&](hlfir::DestroyOp op) { destroy = op; }); 55 56 if (!apply || !destroy) 57 return std::nullopt; 58 59 // we can't inline if the return type of the yield doesn't match the return 60 // type of the apply 61 auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>( 62 elemental.getRegion().back().back()); 63 assert(yield && "hlfir.elemental should always end with a yield"); 64 if (apply.getResult().getType() != yield.getElementValue().getType()) 65 return std::nullopt; 66 67 return std::pair{apply, destroy}; 68 } 69 70 namespace { 71 class InlineElementalConversion 72 : public mlir::OpRewritePattern<hlfir::ElementalOp> { 73 public: 74 using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern; 75 76 llvm::LogicalResult 77 matchAndRewrite(hlfir::ElementalOp elemental, 78 mlir::PatternRewriter &rewriter) const override { 79 std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple = 80 getTwoUses(elemental); 81 if (!maybeTuple) 82 return rewriter.notifyMatchFailure( 83 elemental, "hlfir.elemental does not have two uses"); 84 85 if (elemental.isOrdered()) { 86 // We can only inline the ordered elemental into a loop-like 87 // construct that processes the indices in-order and does not 88 // have the side effects itself. Adhere to conservative behavior 89 // for the time being. 90 return rewriter.notifyMatchFailure(elemental, 91 "hlfir.elemental is ordered"); 92 } 93 auto [apply, destroy] = *maybeTuple; 94 95 assert(elemental.getRegion().hasOneBlock() && 96 "expect elemental region to have one block"); 97 98 fir::FirOpBuilder builder{rewriter, elemental.getOperation()}; 99 builder.setInsertionPointAfter(apply); 100 hlfir::YieldElementOp yield = hlfir::inlineElementalOp( 101 elemental.getLoc(), builder, elemental, apply.getIndices()); 102 103 // remove the old elemental and all of the bookkeeping 104 rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue()); 105 rewriter.eraseOp(yield); 106 rewriter.eraseOp(apply); 107 rewriter.eraseOp(destroy); 108 rewriter.eraseOp(elemental); 109 110 return mlir::success(); 111 } 112 }; 113 114 class InlineElementalsPass 115 : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> { 116 public: 117 void runOnOperation() override { 118 mlir::MLIRContext *context = &getContext(); 119 120 mlir::GreedyRewriteConfig config; 121 // Prevent the pattern driver from merging blocks. 122 config.enableRegionSimplification = 123 mlir::GreedySimplifyRegionLevel::Disabled; 124 125 mlir::RewritePatternSet patterns(context); 126 patterns.insert<InlineElementalConversion>(context); 127 128 if (mlir::failed(mlir::applyPatternsGreedily( 129 getOperation(), std::move(patterns), config))) { 130 mlir::emitError(getOperation()->getLoc(), 131 "failure in HLFIR elemental inlining"); 132 signalPassFailure(); 133 } 134 } 135 }; 136 } // namespace 137