xref: /llvm-project/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp (revision 26e0ce0b3633c67e09d2f3a99e0d4058a4e0a887)
//===-- FIROpPatterns.cpp -- FIR to LLVM conversion patterns --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/FIROpPatterns.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "llvm/Support/Debug.h"
16 
17 static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
18                                         unsigned addressSpace = 0) {
19   return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
20 }
21 
22 static unsigned getTypeDescFieldId(mlir::Type ty) {
23   auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty));
24   return isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
25 }
26 
27 namespace fir {
28 
29 ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern(
30     llvm::StringRef rootOpName, mlir::MLIRContext *context,
31     const fir::LLVMTypeConverter &typeConverter,
32     const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit)
33     : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
34       options(options) {}
35 
36 // Convert FIR type to LLVM without turning fir.box<T> into memory
37 // reference.
38 mlir::Type
39 ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const {
40   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
41     return lowerTy().convertBoxTypeAsStruct(boxTy);
42   return lowerTy().convertType(firType);
43 }
44 
45 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant(
46     mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
47     int value) const {
48   mlir::Type i32Ty = rewriter.getI32Type();
49   mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
50   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
51 }
52 
53 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset(
54     mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
55     int offset) const {
56   mlir::Type ity = lowerTy().offsetType();
57   mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset);
58   return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
59 }
60 
61 /// Perform an extension or truncation as needed on an integer value. Lowering
62 /// to the specific target may involve some sign-extending or truncation of
63 /// values, particularly to fit them from abstract box types to the
64 /// appropriate reified structures.
65 mlir::Value
66 ConvertFIRToLLVMPattern::integerCast(mlir::Location loc,
67                                      mlir::ConversionPatternRewriter &rewriter,
68                                      mlir::Type ty, mlir::Value val) const {
69   auto valTy = val.getType();
70   // If the value was not yet lowered, lower its type so that it can
71   // be used in getPrimitiveTypeSizeInBits.
72   if (!mlir::isa<mlir::IntegerType>(valTy))
73     valTy = convertType(valTy);
74   auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
75   auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
76   if (toSize < fromSize)
77     return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
78   if (toSize > fromSize)
79     return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
80   return val;
81 }
82 
83 fir::ConvertFIRToLLVMPattern::TypePair
84 ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const {
85   mlir::Type llvmBoxTy =
86       lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy));
87   return TypePair{firBoxTy, llvmBoxTy};
88 }
89 
90 /// Construct code sequence to extract the specific value from a `fir.box`.
91 mlir::Value ConvertFIRToLLVMPattern::getValueFromBox(
92     mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy,
93     mlir::ConversionPatternRewriter &rewriter, int boxValue) const {
94   if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
95     auto pty = getLlvmPtrType(resultTy.getContext());
96     auto p = rewriter.create<mlir::LLVM::GEPOp>(
97         loc, pty, boxTy.llvm, box,
98         llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
99     auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, resultTy, p);
100     attachTBAATag(loadOp, boxTy.fir, nullptr, p);
101     return loadOp;
102   }
103   return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue);
104 }
105 
106 /// Method to construct code sequence to get the triple for dimension `dim`
107 /// from a box.
108 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
109     mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
110     mlir::Value box, mlir::Value dim,
111     mlir::ConversionPatternRewriter &rewriter) const {
112   mlir::Value l0 =
113       loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
114   mlir::Value l1 =
115       loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
116   mlir::Value l2 =
117       loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
118   return {l0, l1, l2};
119 }
120 
121 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
122     mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
123     mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const {
124   mlir::Value l0 =
125       getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
126   mlir::Value l1 =
127       getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
128   mlir::Value l2 =
129       getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
130   return {l0, l1, l2};
131 }
132 
133 mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox(
134     mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim,
135     int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
136   assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) &&
137          "descriptor inquiry with runtime dim can only be done on descriptor "
138          "in memory");
139   mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
140                                static_cast<int>(kDimsPosInBox), dim, off);
141   auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
142   attachTBAATag(loadOp, boxTy.fir, nullptr, p);
143   return loadOp;
144 }
145 
146 mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox(
147     mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off,
148     mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
149   if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
150     mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
151                                  static_cast<int>(kDimsPosInBox), dim, off);
152     auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
153     attachTBAATag(loadOp, boxTy.fir, nullptr, p);
154     return loadOp;
155   }
156   return rewriter.create<mlir::LLVM::ExtractValueOp>(
157       loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off});
158 }
159 
160 mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox(
161     mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim,
162     mlir::ConversionPatternRewriter &rewriter) const {
163   auto idxTy = lowerTy().indexType();
164   return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy,
165                             rewriter);
166 }
167 
168 /// Read base address from a fir.box. Returned address has type ty.
169 mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox(
170     mlir::Location loc, TypePair boxTy, mlir::Value box,
171     mlir::ConversionPatternRewriter &rewriter) const {
172   mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
173   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
174 }
175 
176 mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox(
177     mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box,
178     mlir::ConversionPatternRewriter &rewriter) const {
179   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox);
180 }
181 
182 /// Read base address from a fir.box. Returned address has type ty.
183 mlir::Value ConvertFIRToLLVMPattern::getRankFromBox(
184     mlir::Location loc, TypePair boxTy, mlir::Value box,
185     mlir::ConversionPatternRewriter &rewriter) const {
186   mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox});
187   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox);
188 }
189 
190 // Get the element type given an LLVM type that is of the form
191 // (array|struct|vector)+ and the provided indexes.
192 mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy(
193     mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const {
194   for (unsigned i : indexes) {
195     if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) {
196       assert(!t.isOpaque() && i < t.getBody().size());
197       type = t.getBody()[i];
198     } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
199       type = t.getElementType();
200     } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) {
201       type = t.getElementType();
202     } else {
203       fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()),
204                           "request for invalid box element type");
205     }
206   }
207   return type;
208 }
209 
210 // Return LLVM type of the object described by a fir.box of \p boxType.
211 mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType(
212     mlir::Type boxType) const {
213   mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType);
214   assert(objectType && "boxType must be a box type");
215   return this->convertType(objectType);
216 }
217 
218 /// Read the address of the type descriptor from a box.
219 mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress(
220     mlir::Location loc, TypePair boxTy, mlir::Value box,
221     mlir::ConversionPatternRewriter &rewriter) const {
222   unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir);
223   mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext());
224   return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId);
225 }
226 
227 // Load the attribute from the \p box and perform a check against \p maskValue
228 // The final comparison is implemented as `(attribute & maskValue) != 0`.
229 mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
230     mlir::Location loc, TypePair boxTy, mlir::Value box,
231     mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const {
232   mlir::Type attrTy = rewriter.getI32Type();
233   mlir::Value attribute =
234       getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox);
235   mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue);
236   auto maskRes =
237       rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask);
238   mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0);
239   return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne,
240                                              maskRes, c0);
241 }
242 
243 // Find the Block in which the alloca should be inserted.
244 // The order to recursively find the proper block:
245 // 1. An OpenMP Op that will be outlined.
246 // 2. A LLVMFuncOp
247 // 3. The first ancestor that is an OpenMP Op or a LLVMFuncOp
248 mlir::Block *
249 ConvertFIRToLLVMPattern::getBlockForAllocaInsert(mlir::Operation *op) const {
250   if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op))
251     return iface.getAllocaBlock();
252   if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op))
253     return &llvmFuncOp.front();
254 
255   return getBlockForAllocaInsert(op->getParentOp());
256 }
257 
258 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the
259 // allocation address space provided for the architecture in the DataLayout
260 // specification. If the address space is different from the devices
261 // program address space we perform a cast. In the case of most architectures
262 // the program and allocation address space will be the default of 0 and no
263 // cast will be emitted.
264 mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType(
265     mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment,
266     mlir::ConversionPatternRewriter &rewriter) const {
267   auto thisPt = rewriter.saveInsertionPoint();
268   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
269   if (mlir::isa<mlir::omp::DeclareReductionOp>(parentOp) ||
270       mlir::isa<mlir::omp::PrivateClauseOp>(parentOp)) {
271     // DeclareReductionOp & PrivateClauseOp have multiple child regions. We want
272     // to get the first block of whichever of those regions we are currently in
273     mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent();
274     rewriter.setInsertionPointToStart(&parentRegion->front());
275   } else {
276     mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
277     rewriter.setInsertionPointToStart(insertBlock);
278   }
279   auto size = genI32Constant(loc, rewriter, 1);
280   unsigned allocaAs = getAllocaAddressSpace(rewriter);
281   unsigned programAs = getProgramAddressSpace(rewriter);
282 
283   mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>(
284       loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy,
285       size, alignment);
286 
287   // if our allocation address space, is not the same as the program address
288   // space, then we must emit a cast to the program address space before use.
289   // An example case would be on AMDGPU, where the allocation address space is
290   // the numeric value 5 (private), and the program address space is 0
291   // (generic).
292   if (allocaAs != programAs) {
293     al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
294         loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al);
295   }
296 
297   rewriter.restoreInsertionPoint(thisPt);
298   return al;
299 }
300 
301 unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace(
302     mlir::ConversionPatternRewriter &rewriter) const {
303   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
304   assert(parentOp != nullptr &&
305          "expected insertion block to have parent operation");
306   if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
307     if (mlir::Attribute addrSpace =
308             mlir::DataLayout(module).getAllocaMemorySpace())
309       return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
310   return defaultAddressSpace;
311 }
312 
313 unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
314     mlir::ConversionPatternRewriter &rewriter) const {
315   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
316   assert(parentOp != nullptr &&
317          "expected insertion block to have parent operation");
318   if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
319     if (mlir::Attribute addrSpace =
320             mlir::DataLayout(module).getProgramMemorySpace())
321       return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
322   return defaultAddressSpace;
323 }
324 
325 } // namespace fir
326