xref: /llvm-project/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp (revision 8fc9b4efd25287d6ef87e4e25ea40789401ecf7f)
//===-- FIROpPatterns.cpp -- FIR to LLVM conversion pattern helpers -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/FIROpPatterns.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "llvm/Support/Debug.h"
16 
17 static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
18                                         unsigned addressSpace = 0) {
19   return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
20 }
21 
22 static unsigned getTypeDescFieldId(mlir::Type ty) {
23   auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty));
24   return isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
25 }
26 
27 namespace fir {
28 
29 ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern(
30     llvm::StringRef rootOpName, mlir::MLIRContext *context,
31     const fir::LLVMTypeConverter &typeConverter,
32     const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit)
33     : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
34       options(options) {}
35 
36 // Convert FIR type to LLVM without turning fir.box<T> into memory
37 // reference.
38 mlir::Type
39 ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const {
40   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
41     return lowerTy().convertBoxTypeAsStruct(boxTy);
42   return lowerTy().convertType(firType);
43 }
44 
45 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant(
46     mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
47     int value) const {
48   mlir::Type i32Ty = rewriter.getI32Type();
49   mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
50   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
51 }
52 
53 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset(
54     mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
55     int offset) const {
56   mlir::Type ity = lowerTy().offsetType();
57   mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset);
58   return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
59 }
60 
61 /// Perform an extension or truncation as needed on an integer value. Lowering
62 /// to the specific target may involve some sign-extending or truncation of
63 /// values, particularly to fit them from abstract box types to the
64 /// appropriate reified structures.
65 mlir::Value ConvertFIRToLLVMPattern::integerCast(
66     mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
67     mlir::Type ty, mlir::Value val, bool fold) const {
68   auto valTy = val.getType();
69   // If the value was not yet lowered, lower its type so that it can
70   // be used in getPrimitiveTypeSizeInBits.
71   if (!mlir::isa<mlir::IntegerType>(valTy))
72     valTy = convertType(valTy);
73   auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
74   auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
75   if (fold) {
76     if (toSize < fromSize)
77       return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
78     if (toSize > fromSize)
79       return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
80   } else {
81     if (toSize < fromSize)
82       return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
83     if (toSize > fromSize)
84       return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
85   }
86   return val;
87 }
88 
89 fir::ConvertFIRToLLVMPattern::TypePair
90 ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const {
91   mlir::Type llvmBoxTy =
92       lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy));
93   return TypePair{firBoxTy, llvmBoxTy};
94 }
95 
96 /// Construct code sequence to extract the specific value from a `fir.box`.
97 mlir::Value ConvertFIRToLLVMPattern::getValueFromBox(
98     mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy,
99     mlir::ConversionPatternRewriter &rewriter, int boxValue) const {
100   if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
101     auto pty = getLlvmPtrType(resultTy.getContext());
102     auto p = rewriter.create<mlir::LLVM::GEPOp>(
103         loc, pty, boxTy.llvm, box,
104         llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
105     auto fldTy = getBoxEleTy(boxTy.llvm, {boxValue});
106     auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, fldTy, p);
107     auto castOp = integerCast(loc, rewriter, resultTy, loadOp);
108     attachTBAATag(loadOp, boxTy.fir, nullptr, p);
109     return castOp;
110   }
111   return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue);
112 }
113 
114 /// Method to construct code sequence to get the triple for dimension `dim`
115 /// from a box.
116 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
117     mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
118     mlir::Value box, mlir::Value dim,
119     mlir::ConversionPatternRewriter &rewriter) const {
120   mlir::Value l0 =
121       loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
122   mlir::Value l1 =
123       loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
124   mlir::Value l2 =
125       loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
126   return {l0, l1, l2};
127 }
128 
129 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
130     mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
131     mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const {
132   mlir::Value l0 =
133       getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
134   mlir::Value l1 =
135       getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
136   mlir::Value l2 =
137       getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
138   return {l0, l1, l2};
139 }
140 
141 mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox(
142     mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim,
143     int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
144   assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) &&
145          "descriptor inquiry with runtime dim can only be done on descriptor "
146          "in memory");
147   mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
148                                static_cast<int>(kDimsPosInBox), dim, off);
149   auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
150   attachTBAATag(loadOp, boxTy.fir, nullptr, p);
151   return loadOp;
152 }
153 
154 mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox(
155     mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off,
156     mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
157   if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
158     mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
159                                  static_cast<int>(kDimsPosInBox), dim, off);
160     auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
161     attachTBAATag(loadOp, boxTy.fir, nullptr, p);
162     return loadOp;
163   }
164   return rewriter.create<mlir::LLVM::ExtractValueOp>(
165       loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off});
166 }
167 
168 mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox(
169     mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim,
170     mlir::ConversionPatternRewriter &rewriter) const {
171   auto idxTy = lowerTy().indexType();
172   return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy,
173                             rewriter);
174 }
175 
176 /// Read base address from a fir.box. Returned address has type ty.
177 mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox(
178     mlir::Location loc, TypePair boxTy, mlir::Value box,
179     mlir::ConversionPatternRewriter &rewriter) const {
180   mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
181   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
182 }
183 
184 mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox(
185     mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box,
186     mlir::ConversionPatternRewriter &rewriter) const {
187   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox);
188 }
189 
190 /// Read base address from a fir.box. Returned address has type ty.
191 mlir::Value ConvertFIRToLLVMPattern::getRankFromBox(
192     mlir::Location loc, TypePair boxTy, mlir::Value box,
193     mlir::ConversionPatternRewriter &rewriter) const {
194   mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox});
195   return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox);
196 }
197 
198 // Get the element type given an LLVM type that is of the form
199 // (array|struct|vector)+ and the provided indexes.
200 mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy(
201     mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const {
202   for (unsigned i : indexes) {
203     if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) {
204       assert(!t.isOpaque() && i < t.getBody().size());
205       type = t.getBody()[i];
206     } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
207       type = t.getElementType();
208     } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) {
209       type = t.getElementType();
210     } else {
211       fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()),
212                           "request for invalid box element type");
213     }
214   }
215   return type;
216 }
217 
218 // Return LLVM type of the object described by a fir.box of \p boxType.
219 mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType(
220     mlir::Type boxType) const {
221   mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType);
222   assert(objectType && "boxType must be a box type");
223   return this->convertType(objectType);
224 }
225 
226 /// Read the address of the type descriptor from a box.
227 mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress(
228     mlir::Location loc, TypePair boxTy, mlir::Value box,
229     mlir::ConversionPatternRewriter &rewriter) const {
230   unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir);
231   mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext());
232   return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId);
233 }
234 
235 // Load the attribute from the \p box and perform a check against \p maskValue
236 // The final comparison is implemented as `(attribute & maskValue) != 0`.
237 mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
238     mlir::Location loc, TypePair boxTy, mlir::Value box,
239     mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const {
240   mlir::Type attrTy = rewriter.getI32Type();
241   mlir::Value attribute =
242       getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox);
243   mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue);
244   auto maskRes =
245       rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask);
246   mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0);
247   return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne,
248                                              maskRes, c0);
249 }
250 
251 mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
252     mlir::Location loc, TypePair boxTy, mlir::Value box,
253     mlir::ConversionPatternRewriter &rewriter) const {
254   auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
255   assert(firBoxType && "must be a BaseBoxType");
256   const mlir::DataLayout &dl = lowerTy().getDataLayout();
257   if (!firBoxType.isAssumedRank())
258     return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
259   fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
260   mlir::Type llvmScalarBoxType =
261       lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
262   llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
263   mlir::Value scalarBoxSize =
264       genConstantOffset(loc, rewriter, scalarBoxSizeCst);
265   mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
266   mlir::Value rank =
267       integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
268   mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
269   llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType);
270   assert((scalarBoxSizeCst + sizePerDimCst ==
271           dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
272               firBoxType.getBoxTypeWithNewShape(1)))) &&
273          "descriptor layout requires adding padding for dim field");
274   mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
275   mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
276       loc, sizePerDim.getType(), sizePerDim, rank);
277   mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
278       loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
279   return size;
280 }
281 
282 // Find the Block in which the alloca should be inserted.
283 // The order to recursively find the proper block:
284 // 1. An OpenMP Op that will be outlined.
285 // 2. An OpenMP or OpenACC Op with one or more regions holding executable code.
286 // 3. A LLVMFuncOp
287 // 4. The first ancestor that is one of the above.
288 mlir::Block *ConvertFIRToLLVMPattern::getBlockForAllocaInsert(
289     mlir::Operation *op, mlir::Region *parentRegion) const {
290   if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op))
291     return iface.getAllocaBlock();
292   if (auto recipeIface = mlir::dyn_cast<mlir::accomp::RecipeInterface>(op))
293     return recipeIface.getAllocaBlock(*parentRegion);
294   if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op))
295     return &llvmFuncOp.front();
296 
297   return getBlockForAllocaInsert(op->getParentOp(), parentRegion);
298 }
299 
300 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the
301 // allocation address space provided for the architecture in the DataLayout
302 // specification. If the address space is different from the devices
303 // program address space we perform a cast. In the case of most architectures
304 // the program and allocation address space will be the default of 0 and no
305 // cast will be emitted.
306 mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType(
307     mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment,
308     mlir::ConversionPatternRewriter &rewriter) const {
309   auto thisPt = rewriter.saveInsertionPoint();
310   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
311   mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent();
312   mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp, parentRegion);
313   rewriter.setInsertionPointToStart(insertBlock);
314   auto size = genI32Constant(loc, rewriter, 1);
315   unsigned allocaAs = getAllocaAddressSpace(rewriter);
316   unsigned programAs = getProgramAddressSpace(rewriter);
317 
318   mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>(
319       loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy,
320       size, alignment);
321 
322   // if our allocation address space, is not the same as the program address
323   // space, then we must emit a cast to the program address space before use.
324   // An example case would be on AMDGPU, where the allocation address space is
325   // the numeric value 5 (private), and the program address space is 0
326   // (generic).
327   if (allocaAs != programAs) {
328     al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
329         loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al);
330   }
331 
332   rewriter.restoreInsertionPoint(thisPt);
333   return al;
334 }
335 
336 unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace(
337     mlir::ConversionPatternRewriter &rewriter) const {
338   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
339   assert(parentOp != nullptr &&
340          "expected insertion block to have parent operation");
341   if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
342     if (mlir::Attribute addrSpace =
343             mlir::DataLayout(module).getAllocaMemorySpace())
344       return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
345   return defaultAddressSpace;
346 }
347 
348 unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
349     mlir::ConversionPatternRewriter &rewriter) const {
350   mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
351   assert(parentOp != nullptr &&
352          "expected insertion block to have parent operation");
353   if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
354     if (mlir::Attribute addrSpace =
355             mlir::DataLayout(module).getProgramMemorySpace())
356       return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
357   return defaultAddressSpace;
358 }
359 
360 } // namespace fir
361