1 //===-- FIROpPatterns.h -- FIR operation conversion patterns ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H 10 #define FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H 11 12 #include "flang/Optimizer/CodeGen/TypeConverter.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 16 namespace fir { 17 18 struct FIRToLLVMPassOptions; 19 20 static constexpr unsigned defaultAddressSpace = 0u; 21 22 class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern { 23 public: 24 ConvertFIRToLLVMPattern(llvm::StringRef rootOpName, 25 mlir::MLIRContext *context, 26 const fir::LLVMTypeConverter &typeConverter, 27 const fir::FIRToLLVMPassOptions &options, 28 mlir::PatternBenefit benefit = 1); 29 30 protected: 31 mlir::Type convertType(mlir::Type ty) const { 32 return lowerTy().convertType(ty); 33 } 34 35 // Convert FIR type to LLVM without turning fir.box<T> into memory 36 // reference. 37 mlir::Type convertObjectType(mlir::Type firType) const; 38 39 mlir::LLVM::ConstantOp 40 genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 41 int value) const; 42 43 mlir::LLVM::ConstantOp 44 genConstantOffset(mlir::Location loc, 45 mlir::ConversionPatternRewriter &rewriter, 46 int offset) const; 47 48 /// Perform an extension or truncation as needed on an integer value. Lowering 49 /// to the specific target may involve some sign-extending or truncation of 50 /// values, particularly to fit them from abstract box types to the 51 /// appropriate reified structures. 52 mlir::Value integerCast(mlir::Location loc, 53 mlir::ConversionPatternRewriter &rewriter, 54 mlir::Type ty, mlir::Value val, 55 bool fold = false) const; 56 57 struct TypePair { 58 mlir::Type fir; 59 mlir::Type llvm; 60 }; 61 62 TypePair getBoxTypePair(mlir::Type firBoxTy) const; 63 64 /// Construct code sequence to extract the specific value from a `fir.box`. 65 mlir::Value getValueFromBox(mlir::Location loc, TypePair boxTy, 66 mlir::Value box, mlir::Type resultTy, 67 mlir::ConversionPatternRewriter &rewriter, 68 int boxValue) const; 69 70 /// Method to construct code sequence to get the triple for dimension `dim` 71 /// from a box. 72 llvm::SmallVector<mlir::Value, 3> 73 getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, 74 TypePair boxTy, mlir::Value box, mlir::Value dim, 75 mlir::ConversionPatternRewriter &rewriter) const; 76 77 llvm::SmallVector<mlir::Value, 3> 78 getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, 79 TypePair boxTy, mlir::Value box, int dim, 80 mlir::ConversionPatternRewriter &rewriter) const; 81 82 mlir::Value 83 loadDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box, 84 mlir::Value dim, int off, mlir::Type ty, 85 mlir::ConversionPatternRewriter &rewriter) const; 86 87 mlir::Value 88 getDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box, 89 int dim, int off, mlir::Type ty, 90 mlir::ConversionPatternRewriter &rewriter) const; 91 92 mlir::Value getStrideFromBox(mlir::Location loc, TypePair boxTy, 93 mlir::Value box, unsigned dim, 94 mlir::ConversionPatternRewriter &rewriter) const; 95 96 /// Read base address from a fir.box. Returned address has type ty. 97 mlir::Value 98 getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box, 99 mlir::ConversionPatternRewriter &rewriter) const; 100 101 mlir::Value 102 getElementSizeFromBox(mlir::Location loc, mlir::Type resultTy, TypePair boxTy, 103 mlir::Value box, 104 mlir::ConversionPatternRewriter &rewriter) const; 105 106 mlir::Value getRankFromBox(mlir::Location loc, TypePair boxTy, 107 mlir::Value box, 108 mlir::ConversionPatternRewriter &rewriter) const; 109 110 mlir::Value getExtraFromBox(mlir::Location loc, TypePair boxTy, 111 mlir::Value box, 112 mlir::ConversionPatternRewriter &rewriter) const; 113 114 // Get the element type given an LLVM type that is of the form 115 // (array|struct|vector)+ and the provided indexes. 116 mlir::Type getBoxEleTy(mlir::Type type, 117 llvm::ArrayRef<std::int64_t> indexes) const; 118 119 // Return LLVM type of the object described by a fir.box of \p boxType. 120 mlir::Type getLlvmObjectTypeFromBoxType(mlir::Type boxType) const; 121 122 /// Read the address of the type descriptor from a box. 123 mlir::Value 124 loadTypeDescAddress(mlir::Location loc, TypePair boxTy, mlir::Value box, 125 mlir::ConversionPatternRewriter &rewriter) const; 126 127 // Load the attribute from the \p box and perform a check against \p maskValue 128 // The final comparison is implemented as `(attribute & maskValue) != 0`. 129 mlir::Value genBoxAttributeCheck(mlir::Location loc, TypePair boxTy, 130 mlir::Value box, 131 mlir::ConversionPatternRewriter &rewriter, 132 unsigned maskValue) const; 133 134 /// Compute the descriptor size in bytes. The result is not guaranteed to be a 135 /// compile time constant if the box is for an assumed rank, in which case the 136 /// box rank will be read. 137 mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box, 138 mlir::ConversionPatternRewriter &rewriter) const; 139 140 template <typename... ARGS> 141 mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty, 142 mlir::ConversionPatternRewriter &rewriter, 143 mlir::Value base, ARGS... args) const { 144 llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...}; 145 auto llvmPtrTy = 146 mlir::LLVM::LLVMPointerType::get(ty.getContext(), /*addressSpace=*/0); 147 return rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv); 148 } 149 150 // Find the Block in which the alloca should be inserted. 151 // The order to recursively find the proper block: 152 // 1. An OpenMP Op that will be outlined. 153 // 2. An OpenMP or OpenACC Op with one or more regions holding executable 154 // code. 155 // 3. A LLVMFuncOp 156 // 4. The first ancestor that is one of the above. 157 mlir::Block *getBlockForAllocaInsert(mlir::Operation *op, 158 mlir::Region *parentRegion) const; 159 160 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the 161 // allocation address space provided for the architecture in the DataLayout 162 // specification. If the address space is different from the devices 163 // program address space we perform a cast. In the case of most architectures 164 // the program and allocation address space will be the default of 0 and no 165 // cast will be emitted. 166 mlir::Value 167 genAllocaAndAddrCastWithType(mlir::Location loc, mlir::Type llvmObjectTy, 168 unsigned alignment, 169 mlir::ConversionPatternRewriter &rewriter) const; 170 171 const fir::LLVMTypeConverter &lowerTy() const { 172 return *static_cast<const fir::LLVMTypeConverter *>( 173 this->getTypeConverter()); 174 } 175 176 void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op, 177 mlir::Type baseFIRType, mlir::Type accessFIRType, 178 mlir::LLVM::GEPOp gep) const { 179 lowerTy().attachTBAATag(op, baseFIRType, accessFIRType, gep); 180 } 181 182 unsigned 183 getAllocaAddressSpace(mlir::ConversionPatternRewriter &rewriter) const; 184 185 unsigned 186 getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const; 187 188 const fir::FIRToLLVMPassOptions &options; 189 190 using ConvertToLLVMPattern::match; 191 using ConvertToLLVMPattern::matchAndRewrite; 192 }; 193 194 template <typename SourceOp> 195 class FIROpConversion : public ConvertFIRToLLVMPattern { 196 public: 197 using OpAdaptor = typename SourceOp::Adaptor; 198 using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor< 199 mlir::ArrayRef<mlir::ValueRange>>; 200 201 explicit FIROpConversion(const LLVMTypeConverter &typeConverter, 202 const fir::FIRToLLVMPassOptions &options, 203 mlir::PatternBenefit benefit = 1) 204 : ConvertFIRToLLVMPattern(SourceOp::getOperationName(), 205 &typeConverter.getContext(), typeConverter, 206 options, benefit) {} 207 208 /// Wrappers around the RewritePattern methods that pass the derived op type. 209 void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands, 210 mlir::ConversionPatternRewriter &rewriter) const final { 211 rewrite(mlir::cast<SourceOp>(op), 212 OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter); 213 } 214 void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands, 215 mlir::ConversionPatternRewriter &rewriter) const final { 216 auto sourceOp = llvm::cast<SourceOp>(op); 217 rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp), 218 rewriter); 219 } 220 llvm::LogicalResult match(mlir::Operation *op) const final { 221 return match(mlir::cast<SourceOp>(op)); 222 } 223 llvm::LogicalResult 224 matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands, 225 mlir::ConversionPatternRewriter &rewriter) const final { 226 return matchAndRewrite(mlir::cast<SourceOp>(op), 227 OpAdaptor(operands, mlir::cast<SourceOp>(op)), 228 rewriter); 229 } 230 llvm::LogicalResult 231 matchAndRewrite(mlir::Operation *op, 232 mlir::ArrayRef<mlir::ValueRange> operands, 233 mlir::ConversionPatternRewriter &rewriter) const final { 234 auto sourceOp = mlir::cast<SourceOp>(op); 235 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), 236 rewriter); 237 } 238 /// Rewrite and Match methods that operate on the SourceOp type. These must be 239 /// overridden by the derived pattern class. 240 virtual llvm::LogicalResult match(SourceOp op) const { 241 llvm_unreachable("must override match or matchAndRewrite"); 242 } 243 virtual void rewrite(SourceOp op, OpAdaptor adaptor, 244 mlir::ConversionPatternRewriter &rewriter) const { 245 llvm_unreachable("must override rewrite or matchAndRewrite"); 246 } 247 virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, 248 mlir::ConversionPatternRewriter &rewriter) const { 249 llvm::SmallVector<mlir::Value> oneToOneOperands = 250 getOneToOneAdaptorOperands(adaptor.getOperands()); 251 rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 252 } 253 virtual llvm::LogicalResult 254 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 255 mlir::ConversionPatternRewriter &rewriter) const { 256 if (mlir::failed(match(op))) 257 return mlir::failure(); 258 rewrite(op, adaptor, rewriter); 259 return mlir::success(); 260 } 261 virtual llvm::LogicalResult 262 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, 263 mlir::ConversionPatternRewriter &rewriter) const { 264 llvm::SmallVector<mlir::Value> oneToOneOperands = 265 getOneToOneAdaptorOperands(adaptor.getOperands()); 266 return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 267 } 268 269 private: 270 using ConvertFIRToLLVMPattern::matchAndRewrite; 271 using ConvertToLLVMPattern::match; 272 }; 273 274 /// FIR conversion pattern template 275 template <typename FromOp> 276 class FIROpAndTypeConversion : public FIROpConversion<FromOp> { 277 public: 278 using FIROpConversion<FromOp>::FIROpConversion; 279 using OpAdaptor = typename FromOp::Adaptor; 280 281 llvm::LogicalResult 282 matchAndRewrite(FromOp op, OpAdaptor adaptor, 283 mlir::ConversionPatternRewriter &rewriter) const final { 284 mlir::Type ty = this->convertType(op.getType()); 285 return doRewrite(op, ty, adaptor, rewriter); 286 } 287 288 virtual llvm::LogicalResult 289 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, 290 mlir::ConversionPatternRewriter &rewriter) const = 0; 291 }; 292 293 } // namespace fir 294 295 #endif // FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H 296