xref: /llvm-project/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp (revision 3c700d131a35ce4b0063a4688dce4a0cb739ca83)
1*3c700d13SSlava Zakharin //===- InlineHLFIRAssign.cpp - Inline hlfir.assign ops --------------------===//
2*3c700d13SSlava Zakharin //
3*3c700d13SSlava Zakharin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*3c700d13SSlava Zakharin // See https://llvm.org/LICENSE.txt for license information.
5*3c700d13SSlava Zakharin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*3c700d13SSlava Zakharin //
7*3c700d13SSlava Zakharin //===----------------------------------------------------------------------===//
8*3c700d13SSlava Zakharin // Transform hlfir.assign array operations into loop nests performing element
9*3c700d13SSlava Zakharin // per element assignments. The inlining is done for trivial data types always,
10*3c700d13SSlava Zakharin // though, we may add performance/code-size heuristics in future.
11*3c700d13SSlava Zakharin //===----------------------------------------------------------------------===//
12*3c700d13SSlava Zakharin 
13*3c700d13SSlava Zakharin #include "flang/Optimizer/Analysis/AliasAnalysis.h"
14*3c700d13SSlava Zakharin #include "flang/Optimizer/Builder/FIRBuilder.h"
15*3c700d13SSlava Zakharin #include "flang/Optimizer/Builder/HLFIRTools.h"
16*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/HLFIROps.h"
17*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/Passes.h"
18*3c700d13SSlava Zakharin #include "flang/Optimizer/OpenMP/Passes.h"
19*3c700d13SSlava Zakharin #include "mlir/IR/PatternMatch.h"
20*3c700d13SSlava Zakharin #include "mlir/Pass/Pass.h"
21*3c700d13SSlava Zakharin #include "mlir/Support/LLVM.h"
22*3c700d13SSlava Zakharin #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23*3c700d13SSlava Zakharin 
24*3c700d13SSlava Zakharin namespace hlfir {
25*3c700d13SSlava Zakharin #define GEN_PASS_DEF_INLINEHLFIRASSIGN
26*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/Passes.h.inc"
27*3c700d13SSlava Zakharin } // namespace hlfir
28*3c700d13SSlava Zakharin 
29*3c700d13SSlava Zakharin #define DEBUG_TYPE "inline-hlfir-assign"
30*3c700d13SSlava Zakharin 
31*3c700d13SSlava Zakharin namespace {
32*3c700d13SSlava Zakharin /// Expand hlfir.assign of array RHS to array LHS into a loop nest
33*3c700d13SSlava Zakharin /// of element-by-element assignments:
34*3c700d13SSlava Zakharin ///   hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>,
35*3c700d13SSlava Zakharin ///                           !fir.ref<!fir.array<3x3xf32>>
36*3c700d13SSlava Zakharin /// into:
37*3c700d13SSlava Zakharin ///   fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered {
38*3c700d13SSlava Zakharin ///     fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered {
39*3c700d13SSlava Zakharin ///       %6 = hlfir.designate %4 (%arg2, %arg1)  :
40*3c700d13SSlava Zakharin ///           (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
41*3c700d13SSlava Zakharin ///       %7 = fir.load %6 : !fir.ref<f32>
42*3c700d13SSlava Zakharin ///       %8 = hlfir.designate %5 (%arg2, %arg1)  :
43*3c700d13SSlava Zakharin ///           (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
44*3c700d13SSlava Zakharin ///       hlfir.assign %7 to %8 : f32, !fir.ref<f32>
45*3c700d13SSlava Zakharin ///     }
46*3c700d13SSlava Zakharin ///   }
47*3c700d13SSlava Zakharin ///
48*3c700d13SSlava Zakharin /// The transformation is correct only when LHS and RHS do not alias.
49*3c700d13SSlava Zakharin /// When RHS is an array expression, then there is no aliasing.
50*3c700d13SSlava Zakharin /// This transformation does not support runtime checking for
51*3c700d13SSlava Zakharin /// non-conforming LHS/RHS arrays' shapes currently.
52*3c700d13SSlava Zakharin class InlineHLFIRAssignConversion
53*3c700d13SSlava Zakharin     : public mlir::OpRewritePattern<hlfir::AssignOp> {
54*3c700d13SSlava Zakharin public:
55*3c700d13SSlava Zakharin   using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
56*3c700d13SSlava Zakharin 
57*3c700d13SSlava Zakharin   llvm::LogicalResult
58*3c700d13SSlava Zakharin   matchAndRewrite(hlfir::AssignOp assign,
59*3c700d13SSlava Zakharin                   mlir::PatternRewriter &rewriter) const override {
60*3c700d13SSlava Zakharin     if (assign.isAllocatableAssignment())
61*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(assign,
62*3c700d13SSlava Zakharin                                          "AssignOp may imply allocation");
63*3c700d13SSlava Zakharin 
64*3c700d13SSlava Zakharin     hlfir::Entity rhs{assign.getRhs()};
65*3c700d13SSlava Zakharin 
66*3c700d13SSlava Zakharin     if (!rhs.isArray())
67*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(assign,
68*3c700d13SSlava Zakharin                                          "AssignOp's RHS is not an array");
69*3c700d13SSlava Zakharin 
70*3c700d13SSlava Zakharin     mlir::Type rhsEleTy = rhs.getFortranElementType();
71*3c700d13SSlava Zakharin     if (!fir::isa_trivial(rhsEleTy))
72*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(
73*3c700d13SSlava Zakharin           assign, "AssignOp's RHS data type is not trivial");
74*3c700d13SSlava Zakharin 
75*3c700d13SSlava Zakharin     hlfir::Entity lhs{assign.getLhs()};
76*3c700d13SSlava Zakharin     if (!lhs.isArray())
77*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(assign,
78*3c700d13SSlava Zakharin                                          "AssignOp's LHS is not an array");
79*3c700d13SSlava Zakharin 
80*3c700d13SSlava Zakharin     mlir::Type lhsEleTy = lhs.getFortranElementType();
81*3c700d13SSlava Zakharin     if (!fir::isa_trivial(lhsEleTy))
82*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(
83*3c700d13SSlava Zakharin           assign, "AssignOp's LHS data type is not trivial");
84*3c700d13SSlava Zakharin 
85*3c700d13SSlava Zakharin     if (lhsEleTy != rhsEleTy)
86*3c700d13SSlava Zakharin       return rewriter.notifyMatchFailure(assign,
87*3c700d13SSlava Zakharin                                          "RHS/LHS element types mismatch");
88*3c700d13SSlava Zakharin 
89*3c700d13SSlava Zakharin     if (!mlir::isa<hlfir::ExprType>(rhs.getType())) {
90*3c700d13SSlava Zakharin       // If RHS is not an hlfir.expr, then we should prove that
91*3c700d13SSlava Zakharin       // LHS and RHS do not alias.
92*3c700d13SSlava Zakharin       // TODO: if they may alias, we can insert hlfir.as_expr for RHS,
93*3c700d13SSlava Zakharin       // and proceed with the inlining.
94*3c700d13SSlava Zakharin       fir::AliasAnalysis aliasAnalysis;
95*3c700d13SSlava Zakharin       mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs);
96*3c700d13SSlava Zakharin       // TODO: use areIdenticalOrDisjointSlices() from
97*3c700d13SSlava Zakharin       // OptimizedBufferization.cpp to check if we can still do the expansion.
98*3c700d13SSlava Zakharin       if (!aliasRes.isNo()) {
99*3c700d13SSlava Zakharin         LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n"
100*3c700d13SSlava Zakharin                                 << "\tLHS: " << lhs << "\n"
101*3c700d13SSlava Zakharin                                 << "\tRHS: " << rhs << "\n"
102*3c700d13SSlava Zakharin                                 << "\tALIAS: " << aliasRes << "\n");
103*3c700d13SSlava Zakharin         return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias");
104*3c700d13SSlava Zakharin       }
105*3c700d13SSlava Zakharin     }
106*3c700d13SSlava Zakharin 
107*3c700d13SSlava Zakharin     mlir::Location loc = assign->getLoc();
108*3c700d13SSlava Zakharin     fir::FirOpBuilder builder(rewriter, assign.getOperation());
109*3c700d13SSlava Zakharin     builder.setInsertionPoint(assign);
110*3c700d13SSlava Zakharin     rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
111*3c700d13SSlava Zakharin     lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
112*3c700d13SSlava Zakharin     mlir::Value shape = hlfir::genShape(loc, builder, lhs);
113*3c700d13SSlava Zakharin     llvm::SmallVector<mlir::Value> extents =
114*3c700d13SSlava Zakharin         hlfir::getIndexExtents(loc, builder, shape);
115*3c700d13SSlava Zakharin     hlfir::LoopNest loopNest =
116*3c700d13SSlava Zakharin         hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
117*3c700d13SSlava Zakharin                            flangomp::shouldUseWorkshareLowering(assign));
118*3c700d13SSlava Zakharin     builder.setInsertionPointToStart(loopNest.body);
119*3c700d13SSlava Zakharin     auto rhsArrayElement =
120*3c700d13SSlava Zakharin         hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
121*3c700d13SSlava Zakharin     rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
122*3c700d13SSlava Zakharin     auto lhsArrayElement =
123*3c700d13SSlava Zakharin         hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
124*3c700d13SSlava Zakharin     builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement);
125*3c700d13SSlava Zakharin     rewriter.eraseOp(assign);
126*3c700d13SSlava Zakharin     return mlir::success();
127*3c700d13SSlava Zakharin   }
128*3c700d13SSlava Zakharin };
129*3c700d13SSlava Zakharin 
130*3c700d13SSlava Zakharin class InlineHLFIRAssignPass
131*3c700d13SSlava Zakharin     : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> {
132*3c700d13SSlava Zakharin public:
133*3c700d13SSlava Zakharin   void runOnOperation() override {
134*3c700d13SSlava Zakharin     mlir::MLIRContext *context = &getContext();
135*3c700d13SSlava Zakharin 
136*3c700d13SSlava Zakharin     mlir::GreedyRewriteConfig config;
137*3c700d13SSlava Zakharin     // Prevent the pattern driver from merging blocks.
138*3c700d13SSlava Zakharin     config.enableRegionSimplification =
139*3c700d13SSlava Zakharin         mlir::GreedySimplifyRegionLevel::Disabled;
140*3c700d13SSlava Zakharin 
141*3c700d13SSlava Zakharin     mlir::RewritePatternSet patterns(context);
142*3c700d13SSlava Zakharin     patterns.insert<InlineHLFIRAssignConversion>(context);
143*3c700d13SSlava Zakharin 
144*3c700d13SSlava Zakharin     if (mlir::failed(mlir::applyPatternsGreedily(
145*3c700d13SSlava Zakharin             getOperation(), std::move(patterns), config))) {
146*3c700d13SSlava Zakharin       mlir::emitError(getOperation()->getLoc(),
147*3c700d13SSlava Zakharin                       "failure in hlfir.assign inlining");
148*3c700d13SSlava Zakharin       signalPassFailure();
149*3c700d13SSlava Zakharin     }
150*3c700d13SSlava Zakharin   }
151*3c700d13SSlava Zakharin };
152*3c700d13SSlava Zakharin } // namespace
153