1*3c700d13SSlava Zakharin //===- InlineHLFIRAssign.cpp - Inline hlfir.assign ops --------------------===// 2*3c700d13SSlava Zakharin // 3*3c700d13SSlava Zakharin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*3c700d13SSlava Zakharin // See https://llvm.org/LICENSE.txt for license information. 5*3c700d13SSlava Zakharin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*3c700d13SSlava Zakharin // 7*3c700d13SSlava Zakharin //===----------------------------------------------------------------------===// 8*3c700d13SSlava Zakharin // Transform hlfir.assign array operations into loop nests performing element 9*3c700d13SSlava Zakharin // per element assignments. The inlining is done for trivial data types always, 10*3c700d13SSlava Zakharin // though, we may add performance/code-size heuristics in future. 11*3c700d13SSlava Zakharin //===----------------------------------------------------------------------===// 12*3c700d13SSlava Zakharin 13*3c700d13SSlava Zakharin #include "flang/Optimizer/Analysis/AliasAnalysis.h" 14*3c700d13SSlava Zakharin #include "flang/Optimizer/Builder/FIRBuilder.h" 15*3c700d13SSlava Zakharin #include "flang/Optimizer/Builder/HLFIRTools.h" 16*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/HLFIROps.h" 17*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/Passes.h" 18*3c700d13SSlava Zakharin #include "flang/Optimizer/OpenMP/Passes.h" 19*3c700d13SSlava Zakharin #include "mlir/IR/PatternMatch.h" 20*3c700d13SSlava Zakharin #include "mlir/Pass/Pass.h" 21*3c700d13SSlava Zakharin #include "mlir/Support/LLVM.h" 22*3c700d13SSlava Zakharin #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23*3c700d13SSlava Zakharin 24*3c700d13SSlava Zakharin namespace hlfir { 25*3c700d13SSlava Zakharin #define GEN_PASS_DEF_INLINEHLFIRASSIGN 26*3c700d13SSlava Zakharin #include "flang/Optimizer/HLFIR/Passes.h.inc" 27*3c700d13SSlava Zakharin } // namespace hlfir 28*3c700d13SSlava Zakharin 29*3c700d13SSlava Zakharin #define DEBUG_TYPE "inline-hlfir-assign" 30*3c700d13SSlava Zakharin 31*3c700d13SSlava Zakharin namespace { 32*3c700d13SSlava Zakharin /// Expand hlfir.assign of array RHS to array LHS into a loop nest 33*3c700d13SSlava Zakharin /// of element-by-element assignments: 34*3c700d13SSlava Zakharin /// hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>, 35*3c700d13SSlava Zakharin /// !fir.ref<!fir.array<3x3xf32>> 36*3c700d13SSlava Zakharin /// into: 37*3c700d13SSlava Zakharin /// fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered { 38*3c700d13SSlava Zakharin /// fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered { 39*3c700d13SSlava Zakharin /// %6 = hlfir.designate %4 (%arg2, %arg1) : 40*3c700d13SSlava Zakharin /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> 41*3c700d13SSlava Zakharin /// %7 = fir.load %6 : !fir.ref<f32> 42*3c700d13SSlava Zakharin /// %8 = hlfir.designate %5 (%arg2, %arg1) : 43*3c700d13SSlava Zakharin /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> 44*3c700d13SSlava Zakharin /// hlfir.assign %7 to %8 : f32, !fir.ref<f32> 45*3c700d13SSlava Zakharin /// } 46*3c700d13SSlava Zakharin /// } 47*3c700d13SSlava Zakharin /// 48*3c700d13SSlava Zakharin /// The transformation is correct only when LHS and RHS do not alias. 49*3c700d13SSlava Zakharin /// When RHS is an array expression, then there is no aliasing. 50*3c700d13SSlava Zakharin /// This transformation does not support runtime checking for 51*3c700d13SSlava Zakharin /// non-conforming LHS/RHS arrays' shapes currently. 52*3c700d13SSlava Zakharin class InlineHLFIRAssignConversion 53*3c700d13SSlava Zakharin : public mlir::OpRewritePattern<hlfir::AssignOp> { 54*3c700d13SSlava Zakharin public: 55*3c700d13SSlava Zakharin using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; 56*3c700d13SSlava Zakharin 57*3c700d13SSlava Zakharin llvm::LogicalResult 58*3c700d13SSlava Zakharin matchAndRewrite(hlfir::AssignOp assign, 59*3c700d13SSlava Zakharin mlir::PatternRewriter &rewriter) const override { 60*3c700d13SSlava Zakharin if (assign.isAllocatableAssignment()) 61*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure(assign, 62*3c700d13SSlava Zakharin "AssignOp may imply allocation"); 63*3c700d13SSlava Zakharin 64*3c700d13SSlava Zakharin hlfir::Entity rhs{assign.getRhs()}; 65*3c700d13SSlava Zakharin 66*3c700d13SSlava Zakharin if (!rhs.isArray()) 67*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure(assign, 68*3c700d13SSlava Zakharin "AssignOp's RHS is not an array"); 69*3c700d13SSlava Zakharin 70*3c700d13SSlava Zakharin mlir::Type rhsEleTy = rhs.getFortranElementType(); 71*3c700d13SSlava Zakharin if (!fir::isa_trivial(rhsEleTy)) 72*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure( 73*3c700d13SSlava Zakharin assign, "AssignOp's RHS data type is not trivial"); 74*3c700d13SSlava Zakharin 75*3c700d13SSlava Zakharin hlfir::Entity lhs{assign.getLhs()}; 76*3c700d13SSlava Zakharin if (!lhs.isArray()) 77*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure(assign, 78*3c700d13SSlava Zakharin "AssignOp's LHS is not an array"); 79*3c700d13SSlava Zakharin 80*3c700d13SSlava Zakharin mlir::Type lhsEleTy = lhs.getFortranElementType(); 81*3c700d13SSlava Zakharin if (!fir::isa_trivial(lhsEleTy)) 82*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure( 83*3c700d13SSlava Zakharin assign, "AssignOp's LHS data type is not trivial"); 84*3c700d13SSlava Zakharin 85*3c700d13SSlava Zakharin if (lhsEleTy != rhsEleTy) 86*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure(assign, 87*3c700d13SSlava Zakharin "RHS/LHS element types mismatch"); 88*3c700d13SSlava Zakharin 89*3c700d13SSlava Zakharin if (!mlir::isa<hlfir::ExprType>(rhs.getType())) { 90*3c700d13SSlava Zakharin // If RHS is not an hlfir.expr, then we should prove that 91*3c700d13SSlava Zakharin // LHS and RHS do not alias. 92*3c700d13SSlava Zakharin // TODO: if they may alias, we can insert hlfir.as_expr for RHS, 93*3c700d13SSlava Zakharin // and proceed with the inlining. 94*3c700d13SSlava Zakharin fir::AliasAnalysis aliasAnalysis; 95*3c700d13SSlava Zakharin mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs); 96*3c700d13SSlava Zakharin // TODO: use areIdenticalOrDisjointSlices() from 97*3c700d13SSlava Zakharin // OptimizedBufferization.cpp to check if we can still do the expansion. 98*3c700d13SSlava Zakharin if (!aliasRes.isNo()) { 99*3c700d13SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n" 100*3c700d13SSlava Zakharin << "\tLHS: " << lhs << "\n" 101*3c700d13SSlava Zakharin << "\tRHS: " << rhs << "\n" 102*3c700d13SSlava Zakharin << "\tALIAS: " << aliasRes << "\n"); 103*3c700d13SSlava Zakharin return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias"); 104*3c700d13SSlava Zakharin } 105*3c700d13SSlava Zakharin } 106*3c700d13SSlava Zakharin 107*3c700d13SSlava Zakharin mlir::Location loc = assign->getLoc(); 108*3c700d13SSlava Zakharin fir::FirOpBuilder builder(rewriter, assign.getOperation()); 109*3c700d13SSlava Zakharin builder.setInsertionPoint(assign); 110*3c700d13SSlava Zakharin rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); 111*3c700d13SSlava Zakharin lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); 112*3c700d13SSlava Zakharin mlir::Value shape = hlfir::genShape(loc, builder, lhs); 113*3c700d13SSlava Zakharin llvm::SmallVector<mlir::Value> extents = 114*3c700d13SSlava Zakharin hlfir::getIndexExtents(loc, builder, shape); 115*3c700d13SSlava Zakharin hlfir::LoopNest loopNest = 116*3c700d13SSlava Zakharin hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, 117*3c700d13SSlava Zakharin flangomp::shouldUseWorkshareLowering(assign)); 118*3c700d13SSlava Zakharin builder.setInsertionPointToStart(loopNest.body); 119*3c700d13SSlava Zakharin auto rhsArrayElement = 120*3c700d13SSlava Zakharin hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); 121*3c700d13SSlava Zakharin rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); 122*3c700d13SSlava Zakharin auto lhsArrayElement = 123*3c700d13SSlava Zakharin hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); 124*3c700d13SSlava Zakharin builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement); 125*3c700d13SSlava Zakharin rewriter.eraseOp(assign); 126*3c700d13SSlava Zakharin return mlir::success(); 127*3c700d13SSlava Zakharin } 128*3c700d13SSlava Zakharin }; 129*3c700d13SSlava Zakharin 130*3c700d13SSlava Zakharin class InlineHLFIRAssignPass 131*3c700d13SSlava Zakharin : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> { 132*3c700d13SSlava Zakharin public: 133*3c700d13SSlava Zakharin void runOnOperation() override { 134*3c700d13SSlava Zakharin mlir::MLIRContext *context = &getContext(); 135*3c700d13SSlava Zakharin 136*3c700d13SSlava Zakharin mlir::GreedyRewriteConfig config; 137*3c700d13SSlava Zakharin // Prevent the pattern driver from merging blocks. 138*3c700d13SSlava Zakharin config.enableRegionSimplification = 139*3c700d13SSlava Zakharin mlir::GreedySimplifyRegionLevel::Disabled; 140*3c700d13SSlava Zakharin 141*3c700d13SSlava Zakharin mlir::RewritePatternSet patterns(context); 142*3c700d13SSlava Zakharin patterns.insert<InlineHLFIRAssignConversion>(context); 143*3c700d13SSlava Zakharin 144*3c700d13SSlava Zakharin if (mlir::failed(mlir::applyPatternsGreedily( 145*3c700d13SSlava Zakharin getOperation(), std::move(patterns), config))) { 146*3c700d13SSlava Zakharin mlir::emitError(getOperation()->getLoc(), 147*3c700d13SSlava Zakharin "failure in hlfir.assign inlining"); 148*3c700d13SSlava Zakharin signalPassFailure(); 149*3c700d13SSlava Zakharin } 150*3c700d13SSlava Zakharin } 151*3c700d13SSlava Zakharin }; 152*3c700d13SSlava Zakharin } // namespace 153