1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/FIROpPatterns.h" 14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 15 #include "llvm/Support/Debug.h" 16 17 static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context, 18 unsigned addressSpace = 0) { 19 return mlir::LLVM::LLVMPointerType::get(context, addressSpace); 20 } 21 22 static unsigned getTypeDescFieldId(mlir::Type ty) { 23 auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty)); 24 return isArray ? kOptTypePtrPosInBox : kDimsPosInBox; 25 } 26 27 namespace fir { 28 29 ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern( 30 llvm::StringRef rootOpName, mlir::MLIRContext *context, 31 const fir::LLVMTypeConverter &typeConverter, 32 const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit) 33 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit), 34 options(options) {} 35 36 // Convert FIR type to LLVM without turning fir.box<T> into memory 37 // reference. 38 mlir::Type 39 ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const { 40 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType)) 41 return lowerTy().convertBoxTypeAsStruct(boxTy); 42 return lowerTy().convertType(firType); 43 } 44 45 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant( 46 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 47 int value) const { 48 mlir::Type i32Ty = rewriter.getI32Type(); 49 mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); 50 return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); 51 } 52 53 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset( 54 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 55 int offset) const { 56 mlir::Type ity = lowerTy().offsetType(); 57 mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset); 58 return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); 59 } 60 61 /// Perform an extension or truncation as needed on an integer value. Lowering 62 /// to the specific target may involve some sign-extending or truncation of 63 /// values, particularly to fit them from abstract box types to the 64 /// appropriate reified structures. 65 mlir::Value 66 ConvertFIRToLLVMPattern::integerCast(mlir::Location loc, 67 mlir::ConversionPatternRewriter &rewriter, 68 mlir::Type ty, mlir::Value val) const { 69 auto valTy = val.getType(); 70 // If the value was not yet lowered, lower its type so that it can 71 // be used in getPrimitiveTypeSizeInBits. 72 if (!mlir::isa<mlir::IntegerType>(valTy)) 73 valTy = convertType(valTy); 74 auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); 75 auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); 76 if (toSize < fromSize) 77 return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); 78 if (toSize > fromSize) 79 return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); 80 return val; 81 } 82 83 fir::ConvertFIRToLLVMPattern::TypePair 84 ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const { 85 mlir::Type llvmBoxTy = 86 lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy)); 87 return TypePair{firBoxTy, llvmBoxTy}; 88 } 89 90 /// Construct code sequence to extract the specific value from a `fir.box`. 91 mlir::Value ConvertFIRToLLVMPattern::getValueFromBox( 92 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy, 93 mlir::ConversionPatternRewriter &rewriter, int boxValue) const { 94 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { 95 auto pty = getLlvmPtrType(resultTy.getContext()); 96 auto p = rewriter.create<mlir::LLVM::GEPOp>( 97 loc, pty, boxTy.llvm, box, 98 llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue}); 99 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, resultTy, p); 100 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 101 return loadOp; 102 } 103 return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue); 104 } 105 106 /// Method to construct code sequence to get the triple for dimension `dim` 107 /// from a box. 108 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( 109 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, 110 mlir::Value box, mlir::Value dim, 111 mlir::ConversionPatternRewriter &rewriter) const { 112 mlir::Value l0 = 113 loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); 114 mlir::Value l1 = 115 loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); 116 mlir::Value l2 = 117 loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); 118 return {l0, l1, l2}; 119 } 120 121 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( 122 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, 123 mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const { 124 mlir::Value l0 = 125 getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); 126 mlir::Value l1 = 127 getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); 128 mlir::Value l2 = 129 getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); 130 return {l0, l1, l2}; 131 } 132 133 mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox( 134 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim, 135 int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { 136 assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) && 137 "descriptor inquiry with runtime dim can only be done on descriptor " 138 "in memory"); 139 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, 140 static_cast<int>(kDimsPosInBox), dim, off); 141 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); 142 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 143 return loadOp; 144 } 145 146 mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox( 147 mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off, 148 mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { 149 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { 150 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, 151 static_cast<int>(kDimsPosInBox), dim, off); 152 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); 153 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 154 return loadOp; 155 } 156 return rewriter.create<mlir::LLVM::ExtractValueOp>( 157 loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off}); 158 } 159 160 mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox( 161 mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim, 162 mlir::ConversionPatternRewriter &rewriter) const { 163 auto idxTy = lowerTy().indexType(); 164 return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy, 165 rewriter); 166 } 167 168 /// Read base address from a fir.box. Returned address has type ty. 169 mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox( 170 mlir::Location loc, TypePair boxTy, mlir::Value box, 171 mlir::ConversionPatternRewriter &rewriter) const { 172 mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext()); 173 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox); 174 } 175 176 mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox( 177 mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box, 178 mlir::ConversionPatternRewriter &rewriter) const { 179 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox); 180 } 181 182 /// Read base address from a fir.box. Returned address has type ty. 183 mlir::Value ConvertFIRToLLVMPattern::getRankFromBox( 184 mlir::Location loc, TypePair boxTy, mlir::Value box, 185 mlir::ConversionPatternRewriter &rewriter) const { 186 mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox}); 187 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox); 188 } 189 190 // Get the element type given an LLVM type that is of the form 191 // (array|struct|vector)+ and the provided indexes. 192 mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy( 193 mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const { 194 for (unsigned i : indexes) { 195 if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) { 196 assert(!t.isOpaque() && i < t.getBody().size()); 197 type = t.getBody()[i]; 198 } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) { 199 type = t.getElementType(); 200 } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) { 201 type = t.getElementType(); 202 } else { 203 fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()), 204 "request for invalid box element type"); 205 } 206 } 207 return type; 208 } 209 210 // Return LLVM type of the object described by a fir.box of \p boxType. 211 mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType( 212 mlir::Type boxType) const { 213 mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType); 214 assert(objectType && "boxType must be a box type"); 215 return this->convertType(objectType); 216 } 217 218 /// Read the address of the type descriptor from a box. 219 mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress( 220 mlir::Location loc, TypePair boxTy, mlir::Value box, 221 mlir::ConversionPatternRewriter &rewriter) const { 222 unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir); 223 mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext()); 224 return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId); 225 } 226 227 // Load the attribute from the \p box and perform a check against \p maskValue 228 // The final comparison is implemented as `(attribute & maskValue) != 0`. 229 mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck( 230 mlir::Location loc, TypePair boxTy, mlir::Value box, 231 mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const { 232 mlir::Type attrTy = rewriter.getI32Type(); 233 mlir::Value attribute = 234 getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox); 235 mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue); 236 auto maskRes = 237 rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask); 238 mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0); 239 return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne, 240 maskRes, c0); 241 } 242 243 mlir::Value ConvertFIRToLLVMPattern::computeBoxSize( 244 mlir::Location loc, TypePair boxTy, mlir::Value box, 245 mlir::ConversionPatternRewriter &rewriter) const { 246 auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir); 247 assert(firBoxType && "must be a BaseBoxType"); 248 const mlir::DataLayout &dl = lowerTy().getDataLayout(); 249 if (!firBoxType.isAssumedRank()) 250 return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm)); 251 fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0); 252 mlir::Type llvmScalarBoxType = 253 lowerTy().convertBoxTypeAsStruct(firScalarBoxType); 254 llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType); 255 mlir::Value scalarBoxSize = 256 genConstantOffset(loc, rewriter, scalarBoxSizeCst); 257 mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter); 258 mlir::Value rank = 259 integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank); 260 mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1}); 261 llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType); 262 assert((scalarBoxSizeCst + sizePerDimCst == 263 dl.getTypeSize(lowerTy().convertBoxTypeAsStruct( 264 firBoxType.getBoxTypeWithNewShape(1)))) && 265 "descriptor layout requires adding padding for dim field"); 266 mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst); 267 mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>( 268 loc, sizePerDim.getType(), sizePerDim, rank); 269 mlir::Value size = rewriter.create<mlir::LLVM::AddOp>( 270 loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize); 271 return size; 272 } 273 274 // Find the Block in which the alloca should be inserted. 275 // The order to recursively find the proper block: 276 // 1. An OpenMP Op that will be outlined. 277 // 2. A LLVMFuncOp 278 // 3. The first ancestor that is an OpenMP Op or a LLVMFuncOp 279 mlir::Block * 280 ConvertFIRToLLVMPattern::getBlockForAllocaInsert(mlir::Operation *op) const { 281 if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op)) 282 return iface.getAllocaBlock(); 283 if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op)) 284 return &llvmFuncOp.front(); 285 286 return getBlockForAllocaInsert(op->getParentOp()); 287 } 288 289 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the 290 // allocation address space provided for the architecture in the DataLayout 291 // specification. If the address space is different from the devices 292 // program address space we perform a cast. In the case of most architectures 293 // the program and allocation address space will be the default of 0 and no 294 // cast will be emitted. 295 mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType( 296 mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment, 297 mlir::ConversionPatternRewriter &rewriter) const { 298 auto thisPt = rewriter.saveInsertionPoint(); 299 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 300 if (mlir::isa<mlir::omp::DeclareReductionOp>(parentOp) || 301 mlir::isa<mlir::omp::PrivateClauseOp>(parentOp)) { 302 // DeclareReductionOp & PrivateClauseOp have multiple child regions. We want 303 // to get the first block of whichever of those regions we are currently in 304 mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent(); 305 rewriter.setInsertionPointToStart(&parentRegion->front()); 306 } else { 307 mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp); 308 rewriter.setInsertionPointToStart(insertBlock); 309 } 310 auto size = genI32Constant(loc, rewriter, 1); 311 unsigned allocaAs = getAllocaAddressSpace(rewriter); 312 unsigned programAs = getProgramAddressSpace(rewriter); 313 314 mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>( 315 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy, 316 size, alignment); 317 318 // if our allocation address space, is not the same as the program address 319 // space, then we must emit a cast to the program address space before use. 320 // An example case would be on AMDGPU, where the allocation address space is 321 // the numeric value 5 (private), and the program address space is 0 322 // (generic). 323 if (allocaAs != programAs) { 324 al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>( 325 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al); 326 } 327 328 rewriter.restoreInsertionPoint(thisPt); 329 return al; 330 } 331 332 unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace( 333 mlir::ConversionPatternRewriter &rewriter) const { 334 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 335 assert(parentOp != nullptr && 336 "expected insertion block to have parent operation"); 337 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) 338 if (mlir::Attribute addrSpace = 339 mlir::DataLayout(module).getAllocaMemorySpace()) 340 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); 341 return defaultAddressSpace; 342 } 343 344 unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace( 345 mlir::ConversionPatternRewriter &rewriter) const { 346 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 347 assert(parentOp != nullptr && 348 "expected insertion block to have parent operation"); 349 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) 350 if (mlir::Attribute addrSpace = 351 mlir::DataLayout(module).getProgramMemorySpace()) 352 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); 353 return defaultAddressSpace; 354 } 355 356 } // namespace fir 357