xref: /llvm-project/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h (revision 58389b220a9354ed6c34bdb9310a35165579c5e3)
1 //===-- FIROpPatterns.h -- FIR operation conversion patterns ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H
10 #define FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H
11 
12 #include "flang/Optimizer/CodeGen/TypeConverter.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 
16 namespace fir {
17 
18 struct FIRToLLVMPassOptions;
19 
/// The default LLVM pointer address space (0).
/// Declared `inline constexpr` (C++17) so every translation unit that includes
/// this header shares one definition instead of each getting its own
/// internal-linkage copy, as `static` would produce.
inline constexpr unsigned defaultAddressSpace = 0u;
21 
22 class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
23 public:
24   ConvertFIRToLLVMPattern(llvm::StringRef rootOpName,
25                           mlir::MLIRContext *context,
26                           const fir::LLVMTypeConverter &typeConverter,
27                           const fir::FIRToLLVMPassOptions &options,
28                           mlir::PatternBenefit benefit = 1);
29 
30 protected:
31   mlir::Type convertType(mlir::Type ty) const {
32     return lowerTy().convertType(ty);
33   }
34 
35   // Convert FIR type to LLVM without turning fir.box<T> into memory
36   // reference.
37   mlir::Type convertObjectType(mlir::Type firType) const;
38 
39   mlir::LLVM::ConstantOp
40   genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
41                  int value) const;
42 
43   mlir::LLVM::ConstantOp
44   genConstantOffset(mlir::Location loc,
45                     mlir::ConversionPatternRewriter &rewriter,
46                     int offset) const;
47 
48   /// Perform an extension or truncation as needed on an integer value. Lowering
49   /// to the specific target may involve some sign-extending or truncation of
50   /// values, particularly to fit them from abstract box types to the
51   /// appropriate reified structures.
52   mlir::Value integerCast(mlir::Location loc,
53                           mlir::ConversionPatternRewriter &rewriter,
54                           mlir::Type ty, mlir::Value val,
55                           bool fold = false) const;
56 
57   struct TypePair {
58     mlir::Type fir;
59     mlir::Type llvm;
60   };
61 
62   TypePair getBoxTypePair(mlir::Type firBoxTy) const;
63 
64   /// Construct code sequence to extract the specific value from a `fir.box`.
65   mlir::Value getValueFromBox(mlir::Location loc, TypePair boxTy,
66                               mlir::Value box, mlir::Type resultTy,
67                               mlir::ConversionPatternRewriter &rewriter,
68                               int boxValue) const;
69 
70   /// Method to construct code sequence to get the triple for dimension `dim`
71   /// from a box.
72   llvm::SmallVector<mlir::Value, 3>
73   getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
74                  TypePair boxTy, mlir::Value box, mlir::Value dim,
75                  mlir::ConversionPatternRewriter &rewriter) const;
76 
77   llvm::SmallVector<mlir::Value, 3>
78   getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
79                  TypePair boxTy, mlir::Value box, int dim,
80                  mlir::ConversionPatternRewriter &rewriter) const;
81 
82   mlir::Value
83   loadDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
84                       mlir::Value dim, int off, mlir::Type ty,
85                       mlir::ConversionPatternRewriter &rewriter) const;
86 
87   mlir::Value
88   getDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
89                      int dim, int off, mlir::Type ty,
90                      mlir::ConversionPatternRewriter &rewriter) const;
91 
92   mlir::Value getStrideFromBox(mlir::Location loc, TypePair boxTy,
93                                mlir::Value box, unsigned dim,
94                                mlir::ConversionPatternRewriter &rewriter) const;
95 
96   /// Read base address from a fir.box. Returned address has type ty.
97   mlir::Value
98   getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
99                      mlir::ConversionPatternRewriter &rewriter) const;
100 
101   mlir::Value
102   getElementSizeFromBox(mlir::Location loc, mlir::Type resultTy, TypePair boxTy,
103                         mlir::Value box,
104                         mlir::ConversionPatternRewriter &rewriter) const;
105 
106   mlir::Value getRankFromBox(mlir::Location loc, TypePair boxTy,
107                              mlir::Value box,
108                              mlir::ConversionPatternRewriter &rewriter) const;
109 
110   mlir::Value getExtraFromBox(mlir::Location loc, TypePair boxTy,
111                               mlir::Value box,
112                               mlir::ConversionPatternRewriter &rewriter) const;
113 
114   // Get the element type given an LLVM type that is of the form
115   // (array|struct|vector)+ and the provided indexes.
116   mlir::Type getBoxEleTy(mlir::Type type,
117                          llvm::ArrayRef<std::int64_t> indexes) const;
118 
119   // Return LLVM type of the object described by a fir.box of \p boxType.
120   mlir::Type getLlvmObjectTypeFromBoxType(mlir::Type boxType) const;
121 
122   /// Read the address of the type descriptor from a box.
123   mlir::Value
124   loadTypeDescAddress(mlir::Location loc, TypePair boxTy, mlir::Value box,
125                       mlir::ConversionPatternRewriter &rewriter) const;
126 
127   // Load the attribute from the \p box and perform a check against \p maskValue
128   // The final comparison is implemented as `(attribute & maskValue) != 0`.
129   mlir::Value genBoxAttributeCheck(mlir::Location loc, TypePair boxTy,
130                                    mlir::Value box,
131                                    mlir::ConversionPatternRewriter &rewriter,
132                                    unsigned maskValue) const;
133 
134   /// Compute the descriptor size in bytes. The result is not guaranteed to be a
135   /// compile time constant if the box is for an assumed rank, in which case the
136   /// box rank will be read.
137   mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
138                              mlir::ConversionPatternRewriter &rewriter) const;
139 
140   template <typename... ARGS>
141   mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
142                            mlir::ConversionPatternRewriter &rewriter,
143                            mlir::Value base, ARGS... args) const {
144     llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
145     auto llvmPtrTy =
146         mlir::LLVM::LLVMPointerType::get(ty.getContext(), /*addressSpace=*/0);
147     return rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv);
148   }
149 
150   // Find the Block in which the alloca should be inserted.
151   // The order to recursively find the proper block:
152   // 1. An OpenMP Op that will be outlined.
153   // 2. An OpenMP or OpenACC Op with one or more regions holding executable
154   // code.
155   // 3. A LLVMFuncOp
156   // 4. The first ancestor that is one of the above.
157   mlir::Block *getBlockForAllocaInsert(mlir::Operation *op,
158                                        mlir::Region *parentRegion) const;
159 
160   // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the
161   // allocation address space provided for the architecture in the DataLayout
162   // specification. If the address space is different from the devices
163   // program address space we perform a cast. In the case of most architectures
164   // the program and allocation address space will be the default of 0 and no
165   // cast will be emitted.
166   mlir::Value
167   genAllocaAndAddrCastWithType(mlir::Location loc, mlir::Type llvmObjectTy,
168                                unsigned alignment,
169                                mlir::ConversionPatternRewriter &rewriter) const;
170 
171   const fir::LLVMTypeConverter &lowerTy() const {
172     return *static_cast<const fir::LLVMTypeConverter *>(
173         this->getTypeConverter());
174   }
175 
176   void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
177                      mlir::Type baseFIRType, mlir::Type accessFIRType,
178                      mlir::LLVM::GEPOp gep) const {
179     lowerTy().attachTBAATag(op, baseFIRType, accessFIRType, gep);
180   }
181 
182   unsigned
183   getAllocaAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;
184 
185   unsigned
186   getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;
187 
188   const fir::FIRToLLVMPassOptions &options;
189 
190   using ConvertToLLVMPattern::match;
191   using ConvertToLLVMPattern::matchAndRewrite;
192 };
193 
194 template <typename SourceOp>
195 class FIROpConversion : public ConvertFIRToLLVMPattern {
196 public:
197   using OpAdaptor = typename SourceOp::Adaptor;
198   using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor<
199       mlir::ArrayRef<mlir::ValueRange>>;
200 
201   explicit FIROpConversion(const LLVMTypeConverter &typeConverter,
202                            const fir::FIRToLLVMPassOptions &options,
203                            mlir::PatternBenefit benefit = 1)
204       : ConvertFIRToLLVMPattern(SourceOp::getOperationName(),
205                                 &typeConverter.getContext(), typeConverter,
206                                 options, benefit) {}
207 
208   /// Wrappers around the RewritePattern methods that pass the derived op type.
209   void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
210                mlir::ConversionPatternRewriter &rewriter) const final {
211     rewrite(mlir::cast<SourceOp>(op),
212             OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
213   }
214   void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
215                mlir::ConversionPatternRewriter &rewriter) const final {
216     auto sourceOp = llvm::cast<SourceOp>(op);
217     rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
218             rewriter);
219   }
220   llvm::LogicalResult match(mlir::Operation *op) const final {
221     return match(mlir::cast<SourceOp>(op));
222   }
223   llvm::LogicalResult
224   matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
225                   mlir::ConversionPatternRewriter &rewriter) const final {
226     return matchAndRewrite(mlir::cast<SourceOp>(op),
227                            OpAdaptor(operands, mlir::cast<SourceOp>(op)),
228                            rewriter);
229   }
230   llvm::LogicalResult
231   matchAndRewrite(mlir::Operation *op,
232                   mlir::ArrayRef<mlir::ValueRange> operands,
233                   mlir::ConversionPatternRewriter &rewriter) const final {
234     auto sourceOp = mlir::cast<SourceOp>(op);
235     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
236                            rewriter);
237   }
238   /// Rewrite and Match methods that operate on the SourceOp type. These must be
239   /// overridden by the derived pattern class.
240   virtual llvm::LogicalResult match(SourceOp op) const {
241     llvm_unreachable("must override match or matchAndRewrite");
242   }
243   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
244                        mlir::ConversionPatternRewriter &rewriter) const {
245     llvm_unreachable("must override rewrite or matchAndRewrite");
246   }
247   virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
248                        mlir::ConversionPatternRewriter &rewriter) const {
249     llvm::SmallVector<mlir::Value> oneToOneOperands =
250         getOneToOneAdaptorOperands(adaptor.getOperands());
251     rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
252   }
253   virtual llvm::LogicalResult
254   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
255                   mlir::ConversionPatternRewriter &rewriter) const {
256     if (mlir::failed(match(op)))
257       return mlir::failure();
258     rewrite(op, adaptor, rewriter);
259     return mlir::success();
260   }
261   virtual llvm::LogicalResult
262   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
263                   mlir::ConversionPatternRewriter &rewriter) const {
264     llvm::SmallVector<mlir::Value> oneToOneOperands =
265         getOneToOneAdaptorOperands(adaptor.getOperands());
266     return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
267   }
268 
269 private:
270   using ConvertFIRToLLVMPattern::matchAndRewrite;
271   using ConvertToLLVMPattern::match;
272 };
273 
274 /// FIR conversion pattern template
275 template <typename FromOp>
276 class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
277 public:
278   using FIROpConversion<FromOp>::FIROpConversion;
279   using OpAdaptor = typename FromOp::Adaptor;
280 
281   llvm::LogicalResult
282   matchAndRewrite(FromOp op, OpAdaptor adaptor,
283                   mlir::ConversionPatternRewriter &rewriter) const final {
284     mlir::Type ty = this->convertType(op.getType());
285     return doRewrite(op, ty, adaptor, rewriter);
286   }
287 
288   virtual llvm::LogicalResult
289   doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
290             mlir::ConversionPatternRewriter &rewriter) const = 0;
291 };
292 
293 } // namespace fir
294 
295 #endif // FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H
296