1 //===-- AssumedRankOpConversion.cpp ---------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Common/Fortran.h" 10 #include "flang/Lower/BuiltinModules.h" 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/Runtime/Support.h" 13 #include "flang/Optimizer/Builder/Todo.h" 14 #include "flang/Optimizer/Dialect/FIRDialect.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Support/TypeCode.h" 17 #include "flang/Optimizer/Support/Utils.h" 18 #include "flang/Optimizer/Transforms/Passes.h" 19 #include "flang/Runtime/support.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 namespace fir { 26 #define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION 27 #include "flang/Optimizer/Transforms/Passes.h.inc" 28 } // namespace fir 29 30 using namespace fir; 31 using namespace mlir; 32 33 namespace { 34 35 static int getCFIAttribute(mlir::Type boxType) { 36 if (fir::isAllocatableType(boxType)) 37 return CFI_attribute_allocatable; 38 if (fir::isPointerType(boxType)) 39 return CFI_attribute_pointer; 40 return CFI_attribute_other; 41 } 42 43 static Fortran::runtime::LowerBoundModifier 44 getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) { 45 switch (modifier) { 46 case fir::LowerBoundModifierAttribute::Preserve: 47 return Fortran::runtime::LowerBoundModifier::Preserve; 48 case fir::LowerBoundModifierAttribute::SetToOnes: 49 return Fortran::runtime::LowerBoundModifier::SetToOnes; 50 case fir::LowerBoundModifierAttribute::SetToZeroes: 51 return Fortran::runtime::LowerBoundModifier::SetToZeroes; 52 } 53 llvm_unreachable("bad modifier code"); 54 } 55 56 class ReboxAssumedRankConv 57 : public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> { 58 public: 59 using OpRewritePattern::OpRewritePattern; 60 61 ReboxAssumedRankConv(mlir::MLIRContext *context, 62 mlir::SymbolTable *symbolTable, fir::KindMapping kindMap) 63 : mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context), 64 symbolTable{symbolTable}, kindMap{kindMap} {}; 65 66 llvm::LogicalResult 67 matchAndRewrite(fir::ReboxAssumedRankOp rebox, 68 mlir::PatternRewriter &rewriter) const override { 69 fir::FirOpBuilder builder{rewriter, kindMap, symbolTable}; 70 mlir::Location loc = rebox.getLoc(); 71 auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType()); 72 mlir::Type newMaxRankBoxType = 73 newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank); 74 // CopyAndUpdateDescriptor FIR interface requires loading 75 // !fir.ref<fir.box> input which is expensive with assumed-rank. It could 76 // be best to add an entry point that takes a non "const" from to cover 77 // this case, but it would be good to indicate to LLVM that from does not 78 // get modified. 79 if (fir::isBoxAddress(rebox.getBox().getType())) 80 TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input"); 81 mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType); 82 mlir::Value newDtype; 83 mlir::Type newEleType = newBoxType.unwrapInnerType(); 84 auto oldBoxType = mlir::cast<fir::BaseBoxType>( 85 fir::unwrapRefType(rebox.getBox().getType())); 86 auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType); 87 if (newDerivedType && !fir::isPolymorphicType(newBoxType) && 88 (fir::isPolymorphicType(oldBoxType) || 89 (newEleType != oldBoxType.unwrapInnerType())) && 90 !fir::isPolymorphicType(newBoxType)) { 91 newDtype = builder.create<fir::TypeDescOp>( 92 loc, mlir::TypeAttr::get(newDerivedType)); 93 } else { 94 newDtype = builder.createNullConstant(loc); 95 } 96 mlir::Value newAttribute = builder.createIntegerConstant( 97 loc, builder.getIntegerType(8), getCFIAttribute(newBoxType)); 98 int lbsModifierCode = 99 static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier())); 100 mlir::Value lowerBoundModifier = builder.createIntegerConstant( 101 loc, builder.getIntegerType(32), lbsModifierCode); 102 fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc, 103 rebox.getBox(), newDtype, 104 newAttribute, lowerBoundModifier); 105 106 mlir::Value descValue = builder.create<fir::LoadOp>(loc, tempDesc); 107 mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue); 108 rewriter.replaceOp(rebox, castDesc); 109 return mlir::success(); 110 } 111 112 private: 113 mlir::SymbolTable *symbolTable = nullptr; 114 fir::KindMapping kindMap; 115 }; 116 117 class IsAssumedSizeConv : public mlir::OpRewritePattern<fir::IsAssumedSizeOp> { 118 public: 119 using OpRewritePattern::OpRewritePattern; 120 121 IsAssumedSizeConv(mlir::MLIRContext *context, mlir::SymbolTable *symbolTable, 122 fir::KindMapping kindMap) 123 : mlir::OpRewritePattern<fir::IsAssumedSizeOp>(context), 124 symbolTable{symbolTable}, kindMap{kindMap} {}; 125 126 llvm::LogicalResult 127 matchAndRewrite(fir::IsAssumedSizeOp isAssumedSizeOp, 128 mlir::PatternRewriter &rewriter) const override { 129 fir::FirOpBuilder builder{rewriter, kindMap, symbolTable}; 130 mlir::Location loc = isAssumedSizeOp.getLoc(); 131 mlir::Value result = 132 fir::runtime::genIsAssumedSize(builder, loc, isAssumedSizeOp.getVal()); 133 rewriter.replaceOp(isAssumedSizeOp, result); 134 return mlir::success(); 135 } 136 137 private: 138 mlir::SymbolTable *symbolTable = nullptr; 139 fir::KindMapping kindMap; 140 }; 141 142 /// Convert FIR structured control flow ops to CFG ops. 143 class AssumedRankOpConversion 144 : public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> { 145 public: 146 void runOnOperation() override { 147 auto *context = &getContext(); 148 mlir::ModuleOp mod = getOperation(); 149 mlir::SymbolTable symbolTable(mod); 150 fir::KindMapping kindMap = fir::getKindMapping(mod); 151 mlir::RewritePatternSet patterns(context); 152 patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap); 153 patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap); 154 mlir::GreedyRewriteConfig config; 155 config.enableRegionSimplification = 156 mlir::GreedySimplifyRegionLevel::Disabled; 157 (void)applyPatternsGreedily(mod, std::move(patterns), config); 158 } 159 }; 160 } // namespace 161