xref: /llvm-project/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1f1d13bbdSjeanPerier //===-- AssumedRankOpConversion.cpp ---------------------------------------===//
2f1d13bbdSjeanPerier //
3f1d13bbdSjeanPerier // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f1d13bbdSjeanPerier // See https://llvm.org/LICENSE.txt for license information.
5f1d13bbdSjeanPerier // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f1d13bbdSjeanPerier //
7f1d13bbdSjeanPerier //===----------------------------------------------------------------------===//
8f1d13bbdSjeanPerier 
9f1d13bbdSjeanPerier #include "flang/Common/Fortran.h"
10f1d13bbdSjeanPerier #include "flang/Lower/BuiltinModules.h"
11f1d13bbdSjeanPerier #include "flang/Optimizer/Builder/FIRBuilder.h"
12f1d13bbdSjeanPerier #include "flang/Optimizer/Builder/Runtime/Support.h"
13f1d13bbdSjeanPerier #include "flang/Optimizer/Builder/Todo.h"
14f1d13bbdSjeanPerier #include "flang/Optimizer/Dialect/FIRDialect.h"
15f1d13bbdSjeanPerier #include "flang/Optimizer/Dialect/FIROps.h"
16f1d13bbdSjeanPerier #include "flang/Optimizer/Support/TypeCode.h"
17f1d13bbdSjeanPerier #include "flang/Optimizer/Support/Utils.h"
18f1d13bbdSjeanPerier #include "flang/Optimizer/Transforms/Passes.h"
19f1d13bbdSjeanPerier #include "flang/Runtime/support.h"
20f1d13bbdSjeanPerier #include "mlir/Dialect/Func/IR/FuncOps.h"
21f1d13bbdSjeanPerier #include "mlir/Pass/Pass.h"
22f1d13bbdSjeanPerier #include "mlir/Transforms/DialectConversion.h"
23f1d13bbdSjeanPerier #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24f1d13bbdSjeanPerier 
25f1d13bbdSjeanPerier namespace fir {
26f1d13bbdSjeanPerier #define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
27f1d13bbdSjeanPerier #include "flang/Optimizer/Transforms/Passes.h.inc"
28f1d13bbdSjeanPerier } // namespace fir
29f1d13bbdSjeanPerier 
30f1d13bbdSjeanPerier using namespace fir;
31f1d13bbdSjeanPerier using namespace mlir;
32f1d13bbdSjeanPerier 
33f1d13bbdSjeanPerier namespace {
34f1d13bbdSjeanPerier 
35f1d13bbdSjeanPerier static int getCFIAttribute(mlir::Type boxType) {
36f1d13bbdSjeanPerier   if (fir::isAllocatableType(boxType))
37f1d13bbdSjeanPerier     return CFI_attribute_allocatable;
38f1d13bbdSjeanPerier   if (fir::isPointerType(boxType))
39f1d13bbdSjeanPerier     return CFI_attribute_pointer;
40f1d13bbdSjeanPerier   return CFI_attribute_other;
41f1d13bbdSjeanPerier }
42f1d13bbdSjeanPerier 
43f1d13bbdSjeanPerier static Fortran::runtime::LowerBoundModifier
44f1d13bbdSjeanPerier getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) {
45f1d13bbdSjeanPerier   switch (modifier) {
46f1d13bbdSjeanPerier   case fir::LowerBoundModifierAttribute::Preserve:
47f1d13bbdSjeanPerier     return Fortran::runtime::LowerBoundModifier::Preserve;
48f1d13bbdSjeanPerier   case fir::LowerBoundModifierAttribute::SetToOnes:
49f1d13bbdSjeanPerier     return Fortran::runtime::LowerBoundModifier::SetToOnes;
50f1d13bbdSjeanPerier   case fir::LowerBoundModifierAttribute::SetToZeroes:
51f1d13bbdSjeanPerier     return Fortran::runtime::LowerBoundModifier::SetToZeroes;
52f1d13bbdSjeanPerier   }
53f1d13bbdSjeanPerier   llvm_unreachable("bad modifier code");
54f1d13bbdSjeanPerier }
55f1d13bbdSjeanPerier 
56f1d13bbdSjeanPerier class ReboxAssumedRankConv
57f1d13bbdSjeanPerier     : public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> {
58f1d13bbdSjeanPerier public:
59f1d13bbdSjeanPerier   using OpRewritePattern::OpRewritePattern;
60f1d13bbdSjeanPerier 
61f1d13bbdSjeanPerier   ReboxAssumedRankConv(mlir::MLIRContext *context,
62f1d13bbdSjeanPerier                        mlir::SymbolTable *symbolTable, fir::KindMapping kindMap)
63f1d13bbdSjeanPerier       : mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context),
64f1d13bbdSjeanPerier         symbolTable{symbolTable}, kindMap{kindMap} {};
65f1d13bbdSjeanPerier 
66db791b27SRamkumar Ramachandra   llvm::LogicalResult
67f1d13bbdSjeanPerier   matchAndRewrite(fir::ReboxAssumedRankOp rebox,
68f1d13bbdSjeanPerier                   mlir::PatternRewriter &rewriter) const override {
69f1d13bbdSjeanPerier     fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
70f1d13bbdSjeanPerier     mlir::Location loc = rebox.getLoc();
71f1d13bbdSjeanPerier     auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType());
72f1d13bbdSjeanPerier     mlir::Type newMaxRankBoxType =
73f1d13bbdSjeanPerier         newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank);
74f1d13bbdSjeanPerier     // CopyAndUpdateDescriptor FIR interface requires loading
75f1d13bbdSjeanPerier     // !fir.ref<fir.box> input which is expensive with assumed-rank. It could
76f1d13bbdSjeanPerier     // be best to add an entry point that takes a non "const" from to cover
77f1d13bbdSjeanPerier     // this case, but it would be good to indicate to LLVM that from does not
78f1d13bbdSjeanPerier     // get modified.
79f1d13bbdSjeanPerier     if (fir::isBoxAddress(rebox.getBox().getType()))
80f1d13bbdSjeanPerier       TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
81f1d13bbdSjeanPerier     mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType);
82f1d13bbdSjeanPerier     mlir::Value newDtype;
83f1d13bbdSjeanPerier     mlir::Type newEleType = newBoxType.unwrapInnerType();
84f1d13bbdSjeanPerier     auto oldBoxType = mlir::cast<fir::BaseBoxType>(
85f1d13bbdSjeanPerier         fir::unwrapRefType(rebox.getBox().getType()));
86f1d13bbdSjeanPerier     auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType);
872b66d283SjeanPerier     if (newDerivedType && !fir::isPolymorphicType(newBoxType) &&
882b66d283SjeanPerier         (fir::isPolymorphicType(oldBoxType) ||
892b66d283SjeanPerier          (newEleType != oldBoxType.unwrapInnerType())) &&
90f1d13bbdSjeanPerier         !fir::isPolymorphicType(newBoxType)) {
91f1d13bbdSjeanPerier       newDtype = builder.create<fir::TypeDescOp>(
92f1d13bbdSjeanPerier           loc, mlir::TypeAttr::get(newDerivedType));
93f1d13bbdSjeanPerier     } else {
94f1d13bbdSjeanPerier       newDtype = builder.createNullConstant(loc);
95f1d13bbdSjeanPerier     }
96f1d13bbdSjeanPerier     mlir::Value newAttribute = builder.createIntegerConstant(
97f1d13bbdSjeanPerier         loc, builder.getIntegerType(8), getCFIAttribute(newBoxType));
98f1d13bbdSjeanPerier     int lbsModifierCode =
99f1d13bbdSjeanPerier         static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier()));
100f1d13bbdSjeanPerier     mlir::Value lowerBoundModifier = builder.createIntegerConstant(
101f1d13bbdSjeanPerier         loc, builder.getIntegerType(32), lbsModifierCode);
102f1d13bbdSjeanPerier     fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc,
103f1d13bbdSjeanPerier                                              rebox.getBox(), newDtype,
104f1d13bbdSjeanPerier                                              newAttribute, lowerBoundModifier);
105f1d13bbdSjeanPerier 
106f1d13bbdSjeanPerier     mlir::Value descValue = builder.create<fir::LoadOp>(loc, tempDesc);
107f1d13bbdSjeanPerier     mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue);
108f1d13bbdSjeanPerier     rewriter.replaceOp(rebox, castDesc);
109f1d13bbdSjeanPerier     return mlir::success();
110f1d13bbdSjeanPerier   }
111f1d13bbdSjeanPerier 
112f1d13bbdSjeanPerier private:
113f1d13bbdSjeanPerier   mlir::SymbolTable *symbolTable = nullptr;
114f1d13bbdSjeanPerier   fir::KindMapping kindMap;
115f1d13bbdSjeanPerier };
116f1d13bbdSjeanPerier 
117539dbfcfSjeanPerier class IsAssumedSizeConv : public mlir::OpRewritePattern<fir::IsAssumedSizeOp> {
118539dbfcfSjeanPerier public:
119539dbfcfSjeanPerier   using OpRewritePattern::OpRewritePattern;
120539dbfcfSjeanPerier 
121539dbfcfSjeanPerier   IsAssumedSizeConv(mlir::MLIRContext *context, mlir::SymbolTable *symbolTable,
122539dbfcfSjeanPerier                     fir::KindMapping kindMap)
123539dbfcfSjeanPerier       : mlir::OpRewritePattern<fir::IsAssumedSizeOp>(context),
124539dbfcfSjeanPerier         symbolTable{symbolTable}, kindMap{kindMap} {};
125539dbfcfSjeanPerier 
126db791b27SRamkumar Ramachandra   llvm::LogicalResult
127539dbfcfSjeanPerier   matchAndRewrite(fir::IsAssumedSizeOp isAssumedSizeOp,
128539dbfcfSjeanPerier                   mlir::PatternRewriter &rewriter) const override {
129539dbfcfSjeanPerier     fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
130539dbfcfSjeanPerier     mlir::Location loc = isAssumedSizeOp.getLoc();
131539dbfcfSjeanPerier     mlir::Value result =
132539dbfcfSjeanPerier         fir::runtime::genIsAssumedSize(builder, loc, isAssumedSizeOp.getVal());
133539dbfcfSjeanPerier     rewriter.replaceOp(isAssumedSizeOp, result);
134539dbfcfSjeanPerier     return mlir::success();
135539dbfcfSjeanPerier   }
136539dbfcfSjeanPerier 
137539dbfcfSjeanPerier private:
138539dbfcfSjeanPerier   mlir::SymbolTable *symbolTable = nullptr;
139539dbfcfSjeanPerier   fir::KindMapping kindMap;
140539dbfcfSjeanPerier };
141539dbfcfSjeanPerier 
142f1d13bbdSjeanPerier /// Convert FIR structured control flow ops to CFG ops.
143f1d13bbdSjeanPerier class AssumedRankOpConversion
144f1d13bbdSjeanPerier     : public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> {
145f1d13bbdSjeanPerier public:
146f1d13bbdSjeanPerier   void runOnOperation() override {
147f1d13bbdSjeanPerier     auto *context = &getContext();
148f1d13bbdSjeanPerier     mlir::ModuleOp mod = getOperation();
149f1d13bbdSjeanPerier     mlir::SymbolTable symbolTable(mod);
150f1d13bbdSjeanPerier     fir::KindMapping kindMap = fir::getKindMapping(mod);
151f1d13bbdSjeanPerier     mlir::RewritePatternSet patterns(context);
152f1d13bbdSjeanPerier     patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
153539dbfcfSjeanPerier     patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
154f1d13bbdSjeanPerier     mlir::GreedyRewriteConfig config;
155a506279eSMehdi Amini     config.enableRegionSimplification =
156a506279eSMehdi Amini         mlir::GreedySimplifyRegionLevel::Disabled;
157*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(mod, std::move(patterns), config);
158f1d13bbdSjeanPerier   }
159f1d13bbdSjeanPerier };
160f1d13bbdSjeanPerier } // namespace
161