xref: /llvm-project/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===-- AssumedRankOpConversion.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Common/Fortran.h"
10 #include "flang/Lower/BuiltinModules.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Runtime/Support.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Support/TypeCode.h"
17 #include "flang/Optimizer/Support/Utils.h"
18 #include "flang/Optimizer/Transforms/Passes.h"
19 #include "flang/Runtime/support.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 namespace fir {
26 #define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
27 #include "flang/Optimizer/Transforms/Passes.h.inc"
28 } // namespace fir
29 
30 using namespace fir;
31 using namespace mlir;
32 
33 namespace {
34 
35 static int getCFIAttribute(mlir::Type boxType) {
36   if (fir::isAllocatableType(boxType))
37     return CFI_attribute_allocatable;
38   if (fir::isPointerType(boxType))
39     return CFI_attribute_pointer;
40   return CFI_attribute_other;
41 }
42 
43 static Fortran::runtime::LowerBoundModifier
44 getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) {
45   switch (modifier) {
46   case fir::LowerBoundModifierAttribute::Preserve:
47     return Fortran::runtime::LowerBoundModifier::Preserve;
48   case fir::LowerBoundModifierAttribute::SetToOnes:
49     return Fortran::runtime::LowerBoundModifier::SetToOnes;
50   case fir::LowerBoundModifierAttribute::SetToZeroes:
51     return Fortran::runtime::LowerBoundModifier::SetToZeroes;
52   }
53   llvm_unreachable("bad modifier code");
54 }
55 
56 class ReboxAssumedRankConv
57     : public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> {
58 public:
59   using OpRewritePattern::OpRewritePattern;
60 
61   ReboxAssumedRankConv(mlir::MLIRContext *context,
62                        mlir::SymbolTable *symbolTable, fir::KindMapping kindMap)
63       : mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context),
64         symbolTable{symbolTable}, kindMap{kindMap} {};
65 
66   llvm::LogicalResult
67   matchAndRewrite(fir::ReboxAssumedRankOp rebox,
68                   mlir::PatternRewriter &rewriter) const override {
69     fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
70     mlir::Location loc = rebox.getLoc();
71     auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType());
72     mlir::Type newMaxRankBoxType =
73         newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank);
74     // CopyAndUpdateDescriptor FIR interface requires loading
75     // !fir.ref<fir.box> input which is expensive with assumed-rank. It could
76     // be best to add an entry point that takes a non "const" from to cover
77     // this case, but it would be good to indicate to LLVM that from does not
78     // get modified.
79     if (fir::isBoxAddress(rebox.getBox().getType()))
80       TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
81     mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType);
82     mlir::Value newDtype;
83     mlir::Type newEleType = newBoxType.unwrapInnerType();
84     auto oldBoxType = mlir::cast<fir::BaseBoxType>(
85         fir::unwrapRefType(rebox.getBox().getType()));
86     auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType);
87     if (newDerivedType && !fir::isPolymorphicType(newBoxType) &&
88         (fir::isPolymorphicType(oldBoxType) ||
89          (newEleType != oldBoxType.unwrapInnerType())) &&
90         !fir::isPolymorphicType(newBoxType)) {
91       newDtype = builder.create<fir::TypeDescOp>(
92           loc, mlir::TypeAttr::get(newDerivedType));
93     } else {
94       newDtype = builder.createNullConstant(loc);
95     }
96     mlir::Value newAttribute = builder.createIntegerConstant(
97         loc, builder.getIntegerType(8), getCFIAttribute(newBoxType));
98     int lbsModifierCode =
99         static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier()));
100     mlir::Value lowerBoundModifier = builder.createIntegerConstant(
101         loc, builder.getIntegerType(32), lbsModifierCode);
102     fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc,
103                                              rebox.getBox(), newDtype,
104                                              newAttribute, lowerBoundModifier);
105 
106     mlir::Value descValue = builder.create<fir::LoadOp>(loc, tempDesc);
107     mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue);
108     rewriter.replaceOp(rebox, castDesc);
109     return mlir::success();
110   }
111 
112 private:
113   mlir::SymbolTable *symbolTable = nullptr;
114   fir::KindMapping kindMap;
115 };
116 
117 class IsAssumedSizeConv : public mlir::OpRewritePattern<fir::IsAssumedSizeOp> {
118 public:
119   using OpRewritePattern::OpRewritePattern;
120 
121   IsAssumedSizeConv(mlir::MLIRContext *context, mlir::SymbolTable *symbolTable,
122                     fir::KindMapping kindMap)
123       : mlir::OpRewritePattern<fir::IsAssumedSizeOp>(context),
124         symbolTable{symbolTable}, kindMap{kindMap} {};
125 
126   llvm::LogicalResult
127   matchAndRewrite(fir::IsAssumedSizeOp isAssumedSizeOp,
128                   mlir::PatternRewriter &rewriter) const override {
129     fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
130     mlir::Location loc = isAssumedSizeOp.getLoc();
131     mlir::Value result =
132         fir::runtime::genIsAssumedSize(builder, loc, isAssumedSizeOp.getVal());
133     rewriter.replaceOp(isAssumedSizeOp, result);
134     return mlir::success();
135   }
136 
137 private:
138   mlir::SymbolTable *symbolTable = nullptr;
139   fir::KindMapping kindMap;
140 };
141 
142 /// Convert FIR structured control flow ops to CFG ops.
143 class AssumedRankOpConversion
144     : public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> {
145 public:
146   void runOnOperation() override {
147     auto *context = &getContext();
148     mlir::ModuleOp mod = getOperation();
149     mlir::SymbolTable symbolTable(mod);
150     fir::KindMapping kindMap = fir::getKindMapping(mod);
151     mlir::RewritePatternSet patterns(context);
152     patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
153     patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
154     mlir::GreedyRewriteConfig config;
155     config.enableRegionSimplification =
156         mlir::GreedySimplifyRegionLevel::Disabled;
157     (void)applyPatternsGreedily(mod, std::move(patterns), config);
158   }
159 };
160 } // namespace
161