1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/FIROpPatterns.h" 14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 15 #include "llvm/Support/Debug.h" 16 17 static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context, 18 unsigned addressSpace = 0) { 19 return mlir::LLVM::LLVMPointerType::get(context, addressSpace); 20 } 21 22 static unsigned getTypeDescFieldId(mlir::Type ty) { 23 auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty)); 24 return isArray ? kOptTypePtrPosInBox : kDimsPosInBox; 25 } 26 27 namespace fir { 28 29 ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern( 30 llvm::StringRef rootOpName, mlir::MLIRContext *context, 31 const fir::LLVMTypeConverter &typeConverter, 32 const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit) 33 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit), 34 options(options) {} 35 36 // Convert FIR type to LLVM without turning fir.box<T> into memory 37 // reference. 38 mlir::Type 39 ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const { 40 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType)) 41 return lowerTy().convertBoxTypeAsStruct(boxTy); 42 return lowerTy().convertType(firType); 43 } 44 45 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant( 46 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 47 int value) const { 48 mlir::Type i32Ty = rewriter.getI32Type(); 49 mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); 50 return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); 51 } 52 53 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset( 54 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 55 int offset) const { 56 mlir::Type ity = lowerTy().offsetType(); 57 mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset); 58 return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); 59 } 60 61 /// Perform an extension or truncation as needed on an integer value. Lowering 62 /// to the specific target may involve some sign-extending or truncation of 63 /// values, particularly to fit them from abstract box types to the 64 /// appropriate reified structures. 65 mlir::Value ConvertFIRToLLVMPattern::integerCast( 66 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, 67 mlir::Type ty, mlir::Value val, bool fold) const { 68 auto valTy = val.getType(); 69 // If the value was not yet lowered, lower its type so that it can 70 // be used in getPrimitiveTypeSizeInBits. 71 if (!mlir::isa<mlir::IntegerType>(valTy)) 72 valTy = convertType(valTy); 73 auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); 74 auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); 75 if (fold) { 76 if (toSize < fromSize) 77 return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val); 78 if (toSize > fromSize) 79 return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); 80 } else { 81 if (toSize < fromSize) 82 return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); 83 if (toSize > fromSize) 84 return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); 85 } 86 return val; 87 } 88 89 fir::ConvertFIRToLLVMPattern::TypePair 90 ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const { 91 mlir::Type llvmBoxTy = 92 lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy)); 93 return TypePair{firBoxTy, llvmBoxTy}; 94 } 95 96 /// Construct code sequence to extract the specific value from a `fir.box`. 97 mlir::Value ConvertFIRToLLVMPattern::getValueFromBox( 98 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy, 99 mlir::ConversionPatternRewriter &rewriter, int boxValue) const { 100 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { 101 auto pty = getLlvmPtrType(resultTy.getContext()); 102 auto p = rewriter.create<mlir::LLVM::GEPOp>( 103 loc, pty, boxTy.llvm, box, 104 llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue}); 105 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, resultTy, p); 106 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 107 return loadOp; 108 } 109 return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue); 110 } 111 112 /// Method to construct code sequence to get the triple for dimension `dim` 113 /// from a box. 114 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( 115 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, 116 mlir::Value box, mlir::Value dim, 117 mlir::ConversionPatternRewriter &rewriter) const { 118 mlir::Value l0 = 119 loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); 120 mlir::Value l1 = 121 loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); 122 mlir::Value l2 = 123 loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); 124 return {l0, l1, l2}; 125 } 126 127 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( 128 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, 129 mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const { 130 mlir::Value l0 = 131 getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); 132 mlir::Value l1 = 133 getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); 134 mlir::Value l2 = 135 getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); 136 return {l0, l1, l2}; 137 } 138 139 mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox( 140 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim, 141 int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { 142 assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) && 143 "descriptor inquiry with runtime dim can only be done on descriptor " 144 "in memory"); 145 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, 146 static_cast<int>(kDimsPosInBox), dim, off); 147 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); 148 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 149 return loadOp; 150 } 151 152 mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox( 153 mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off, 154 mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { 155 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { 156 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, 157 static_cast<int>(kDimsPosInBox), dim, off); 158 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); 159 attachTBAATag(loadOp, boxTy.fir, nullptr, p); 160 return loadOp; 161 } 162 return rewriter.create<mlir::LLVM::ExtractValueOp>( 163 loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off}); 164 } 165 166 mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox( 167 mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim, 168 mlir::ConversionPatternRewriter &rewriter) const { 169 auto idxTy = lowerTy().indexType(); 170 return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy, 171 rewriter); 172 } 173 174 /// Read base address from a fir.box. Returned address has type ty. 175 mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox( 176 mlir::Location loc, TypePair boxTy, mlir::Value box, 177 mlir::ConversionPatternRewriter &rewriter) const { 178 mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext()); 179 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox); 180 } 181 182 mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox( 183 mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box, 184 mlir::ConversionPatternRewriter &rewriter) const { 185 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox); 186 } 187 188 /// Read base address from a fir.box. Returned address has type ty. 189 mlir::Value ConvertFIRToLLVMPattern::getRankFromBox( 190 mlir::Location loc, TypePair boxTy, mlir::Value box, 191 mlir::ConversionPatternRewriter &rewriter) const { 192 mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox}); 193 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox); 194 } 195 196 // Get the element type given an LLVM type that is of the form 197 // (array|struct|vector)+ and the provided indexes. 198 mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy( 199 mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const { 200 for (unsigned i : indexes) { 201 if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) { 202 assert(!t.isOpaque() && i < t.getBody().size()); 203 type = t.getBody()[i]; 204 } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) { 205 type = t.getElementType(); 206 } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) { 207 type = t.getElementType(); 208 } else { 209 fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()), 210 "request for invalid box element type"); 211 } 212 } 213 return type; 214 } 215 216 // Return LLVM type of the object described by a fir.box of \p boxType. 217 mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType( 218 mlir::Type boxType) const { 219 mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType); 220 assert(objectType && "boxType must be a box type"); 221 return this->convertType(objectType); 222 } 223 224 /// Read the address of the type descriptor from a box. 225 mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress( 226 mlir::Location loc, TypePair boxTy, mlir::Value box, 227 mlir::ConversionPatternRewriter &rewriter) const { 228 unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir); 229 mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext()); 230 return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId); 231 } 232 233 // Load the attribute from the \p box and perform a check against \p maskValue 234 // The final comparison is implemented as `(attribute & maskValue) != 0`. 235 mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck( 236 mlir::Location loc, TypePair boxTy, mlir::Value box, 237 mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const { 238 mlir::Type attrTy = rewriter.getI32Type(); 239 mlir::Value attribute = 240 getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox); 241 mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue); 242 auto maskRes = 243 rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask); 244 mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0); 245 return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne, 246 maskRes, c0); 247 } 248 249 mlir::Value ConvertFIRToLLVMPattern::computeBoxSize( 250 mlir::Location loc, TypePair boxTy, mlir::Value box, 251 mlir::ConversionPatternRewriter &rewriter) const { 252 auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir); 253 assert(firBoxType && "must be a BaseBoxType"); 254 const mlir::DataLayout &dl = lowerTy().getDataLayout(); 255 if (!firBoxType.isAssumedRank()) 256 return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm)); 257 fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0); 258 mlir::Type llvmScalarBoxType = 259 lowerTy().convertBoxTypeAsStruct(firScalarBoxType); 260 llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType); 261 mlir::Value scalarBoxSize = 262 genConstantOffset(loc, rewriter, scalarBoxSizeCst); 263 mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter); 264 mlir::Value rank = 265 integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank); 266 mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1}); 267 llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType); 268 assert((scalarBoxSizeCst + sizePerDimCst == 269 dl.getTypeSize(lowerTy().convertBoxTypeAsStruct( 270 firBoxType.getBoxTypeWithNewShape(1)))) && 271 "descriptor layout requires adding padding for dim field"); 272 mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst); 273 mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>( 274 loc, sizePerDim.getType(), sizePerDim, rank); 275 mlir::Value size = rewriter.create<mlir::LLVM::AddOp>( 276 loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize); 277 return size; 278 } 279 280 // Find the Block in which the alloca should be inserted. 281 // The order to recursively find the proper block: 282 // 1. An OpenMP Op that will be outlined. 283 // 2. An OpenMP or OpenACC Op with one or more regions holding executable code. 284 // 3. A LLVMFuncOp 285 // 4. The first ancestor that is one of the above. 286 mlir::Block *ConvertFIRToLLVMPattern::getBlockForAllocaInsert( 287 mlir::Operation *op, mlir::Region *parentRegion) const { 288 if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op)) 289 return iface.getAllocaBlock(); 290 if (auto recipeIface = mlir::dyn_cast<mlir::accomp::RecipeInterface>(op)) 291 return recipeIface.getAllocaBlock(*parentRegion); 292 if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op)) 293 return &llvmFuncOp.front(); 294 295 return getBlockForAllocaInsert(op->getParentOp(), parentRegion); 296 } 297 298 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the 299 // allocation address space provided for the architecture in the DataLayout 300 // specification. If the address space is different from the devices 301 // program address space we perform a cast. In the case of most architectures 302 // the program and allocation address space will be the default of 0 and no 303 // cast will be emitted. 304 mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType( 305 mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment, 306 mlir::ConversionPatternRewriter &rewriter) const { 307 auto thisPt = rewriter.saveInsertionPoint(); 308 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 309 mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent(); 310 mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp, parentRegion); 311 rewriter.setInsertionPointToStart(insertBlock); 312 auto size = genI32Constant(loc, rewriter, 1); 313 unsigned allocaAs = getAllocaAddressSpace(rewriter); 314 unsigned programAs = getProgramAddressSpace(rewriter); 315 316 mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>( 317 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy, 318 size, alignment); 319 320 // if our allocation address space, is not the same as the program address 321 // space, then we must emit a cast to the program address space before use. 322 // An example case would be on AMDGPU, where the allocation address space is 323 // the numeric value 5 (private), and the program address space is 0 324 // (generic). 325 if (allocaAs != programAs) { 326 al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>( 327 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al); 328 } 329 330 rewriter.restoreInsertionPoint(thisPt); 331 return al; 332 } 333 334 unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace( 335 mlir::ConversionPatternRewriter &rewriter) const { 336 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 337 assert(parentOp != nullptr && 338 "expected insertion block to have parent operation"); 339 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) 340 if (mlir::Attribute addrSpace = 341 mlir::DataLayout(module).getAllocaMemorySpace()) 342 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); 343 return defaultAddressSpace; 344 } 345 346 unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace( 347 mlir::ConversionPatternRewriter &rewriter) const { 348 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); 349 assert(parentOp != nullptr && 350 "expected insertion block to have parent operation"); 351 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) 352 if (mlir::Attribute addrSpace = 353 mlir::DataLayout(module).getProgramMemorySpace()) 354 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); 355 return defaultAddressSpace; 356 } 357 358 } // namespace fir 359