1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 10 #include "MemRefDescriptor.h" 11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 14 #include "llvm/ADT/ScopeExit.h" 15 #include "llvm/Support/Threading.h" 16 #include <memory> 17 #include <mutex> 18 #include <optional> 19 20 using namespace mlir; 21 22 SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() { 23 { 24 // Most of the time, the entry already exists in the map. 25 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex, 26 std::defer_lock); 27 if (getContext().isMultithreadingEnabled()) 28 lock.lock(); 29 auto recursiveStack = conversionCallStack.find(llvm::get_threadid()); 30 if (recursiveStack != conversionCallStack.end()) 31 return *recursiveStack->second; 32 } 33 34 // First time this thread gets here, we have to get an exclusive access to 35 // inset in the map 36 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex); 37 auto recursiveStackInserted = conversionCallStack.insert(std::make_pair( 38 llvm::get_threadid(), std::make_unique<SmallVector<Type>>())); 39 return *recursiveStackInserted.first->second; 40 } 41 42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions. 43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, 44 const DataLayoutAnalysis *analysis) 45 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {} 46 47 /// Helper function that checks if the given value range is a bare pointer. 48 static bool isBarePointer(ValueRange values) { 49 return values.size() == 1 && 50 isa<LLVM::LLVMPointerType>(values.front().getType()); 51 } 52 53 /// Pack SSA values into an unranked memref descriptor struct. 54 static Value packUnrankedMemRefDesc(OpBuilder &builder, 55 UnrankedMemRefType resultType, 56 ValueRange inputs, Location loc, 57 const LLVMTypeConverter &converter) { 58 // Note: Bare pointers are not supported for unranked memrefs because a 59 // memref descriptor cannot be built just from a bare pointer. 60 if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields()) 61 return Value(); 62 return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType, 63 inputs); 64 } 65 66 /// Pack SSA values into a ranked memref descriptor struct. 67 static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, 68 ValueRange inputs, Location loc, 69 const LLVMTypeConverter &converter) { 70 assert(resultType && "expected non-null result type"); 71 if (isBarePointer(inputs)) 72 return MemRefDescriptor::fromStaticShape(builder, loc, converter, 73 resultType, inputs[0]); 74 if (TypeRange(inputs) == 75 converter.getMemRefDescriptorFields(resultType, 76 /*unpackAggregates=*/true)) 77 return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs); 78 // The inputs are neither a bare pointer nor an unpacked memref descriptor. 79 // This materialization function cannot be used. 80 return Value(); 81 } 82 83 /// MemRef descriptor elements -> UnrankedMemRefType 84 static Value unrankedMemRefMaterialization(OpBuilder &builder, 85 UnrankedMemRefType resultType, 86 ValueRange inputs, Location loc, 87 const LLVMTypeConverter &converter) { 88 // A source materialization must return a value of type 89 // `resultType`, so insert a cast from the memref descriptor type 90 // (!llvm.struct) to the original memref type. 91 Value packed = 92 packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); 93 if (!packed) 94 return Value(); 95 return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) 96 .getResult(0); 97 } 98 99 /// MemRef descriptor elements -> MemRefType 100 static Value rankedMemRefMaterialization(OpBuilder &builder, 101 MemRefType resultType, 102 ValueRange inputs, Location loc, 103 const LLVMTypeConverter &converter) { 104 // A source materialization must return a value of type `resultType`, 105 // so insert a cast from the memref descriptor type (!llvm.struct) to the 106 // original memref type. 107 Value packed = 108 packRankedMemRefDesc(builder, resultType, inputs, loc, converter); 109 if (!packed) 110 return Value(); 111 return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) 112 .getResult(0); 113 } 114 115 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. 116 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, 117 const LowerToLLVMOptions &options, 118 const DataLayoutAnalysis *analysis) 119 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options), 120 dataLayoutAnalysis(analysis) { 121 assert(llvmDialect && "LLVM IR dialect is not registered"); 122 123 // Register conversions for the builtin types. 124 addConversion([&](ComplexType type) { return convertComplexType(type); }); 125 addConversion([&](FloatType type) { return convertFloatType(type); }); 126 addConversion([&](FunctionType type) { return convertFunctionType(type); }); 127 addConversion([&](IndexType type) { return convertIndexType(type); }); 128 addConversion([&](IntegerType type) { return convertIntegerType(type); }); 129 addConversion([&](MemRefType type) { return convertMemRefType(type); }); 130 addConversion( 131 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); 132 addConversion([&](VectorType type) -> std::optional<Type> { 133 FailureOr<Type> llvmType = convertVectorType(type); 134 if (failed(llvmType)) 135 return std::nullopt; 136 return llvmType; 137 }); 138 139 // LLVM-compatible types are legal, so add a pass-through conversion. Do this 140 // before the conversions below since conversions are attempted in reverse 141 // order and those should take priority. 142 addConversion([](Type type) { 143 return LLVM::isCompatibleType(type) ? std::optional<Type>(type) 144 : std::nullopt; 145 }); 146 147 addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results) 148 -> std::optional<LogicalResult> { 149 // Fastpath for types that won't be converted by this callback anyway. 150 if (LLVM::isCompatibleType(type)) { 151 results.push_back(type); 152 return success(); 153 } 154 155 if (type.isIdentified()) { 156 auto convertedType = LLVM::LLVMStructType::getIdentified( 157 type.getContext(), ("_Converted." + type.getName()).str()); 158 159 SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack(); 160 if (llvm::count(recursiveStack, type)) { 161 results.push_back(convertedType); 162 return success(); 163 } 164 recursiveStack.push_back(type); 165 auto popConversionCallStack = llvm::make_scope_exit( 166 [&recursiveStack]() { recursiveStack.pop_back(); }); 167 168 SmallVector<Type> convertedElemTypes; 169 convertedElemTypes.reserve(type.getBody().size()); 170 if (failed(convertTypes(type.getBody(), convertedElemTypes))) 171 return std::nullopt; 172 173 // If the converted type has not been initialized yet, just set its body 174 // to be the converted arguments and return. 175 if (!convertedType.isInitialized()) { 176 if (failed( 177 convertedType.setBody(convertedElemTypes, type.isPacked()))) { 178 return failure(); 179 } 180 results.push_back(convertedType); 181 return success(); 182 } 183 184 // If it has been initialized, has the same body and packed bit, just use 185 // it. This ensures that recursive structs keep being recursive rather 186 // than including a non-updated name. 187 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) && 188 convertedType.isPacked() == type.isPacked()) { 189 results.push_back(convertedType); 190 return success(); 191 } 192 193 return failure(); 194 } 195 196 SmallVector<Type> convertedSubtypes; 197 convertedSubtypes.reserve(type.getBody().size()); 198 if (failed(convertTypes(type.getBody(), convertedSubtypes))) 199 return std::nullopt; 200 201 results.push_back(LLVM::LLVMStructType::getLiteral( 202 type.getContext(), convertedSubtypes, type.isPacked())); 203 return success(); 204 }); 205 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> { 206 if (auto element = convertType(type.getElementType())) 207 return LLVM::LLVMArrayType::get(element, type.getNumElements()); 208 return std::nullopt; 209 }); 210 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> { 211 Type convertedResType = convertType(type.getReturnType()); 212 if (!convertedResType) 213 return std::nullopt; 214 215 SmallVector<Type> convertedArgTypes; 216 convertedArgTypes.reserve(type.getNumParams()); 217 if (failed(convertTypes(type.getParams(), convertedArgTypes))) 218 return std::nullopt; 219 220 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes, 221 type.isVarArg()); 222 }); 223 224 // Add generic source and target materializations to handle cases where 225 // non-LLVM types persist after an LLVM conversion. 226 addSourceMaterialization([&](OpBuilder &builder, Type resultType, 227 ValueRange inputs, Location loc) { 228 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) 229 .getResult(0); 230 }); 231 addTargetMaterialization([&](OpBuilder &builder, Type resultType, 232 ValueRange inputs, Location loc) { 233 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) 234 .getResult(0); 235 }); 236 237 // Source materializations convert from the new block argument types 238 // (multiple SSA values that make up a memref descriptor) back to the 239 // original block argument type. 240 addSourceMaterialization([&](OpBuilder &builder, 241 UnrankedMemRefType resultType, ValueRange inputs, 242 Location loc) { 243 return unrankedMemRefMaterialization(builder, resultType, inputs, loc, 244 *this); 245 }); 246 addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType, 247 ValueRange inputs, Location loc) { 248 return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); 249 }); 250 251 // Bare pointer -> Packed MemRef descriptor 252 addTargetMaterialization([&](OpBuilder &builder, Type resultType, 253 ValueRange inputs, Location loc, 254 Type originalType) -> Value { 255 // The original MemRef type is required to build a MemRef descriptor 256 // because the sizes/strides of the MemRef cannot be inferred from just the 257 // bare pointer. 258 if (!originalType) 259 return Value(); 260 if (resultType != convertType(originalType)) 261 return Value(); 262 if (auto memrefType = dyn_cast<MemRefType>(originalType)) 263 return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this); 264 if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType)) 265 return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc, 266 *this); 267 return Value(); 268 }); 269 270 // Integer memory spaces map to themselves. 271 addTypeAttributeConversion( 272 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; }); 273 } 274 275 /// Returns the MLIR context. 276 MLIRContext &LLVMTypeConverter::getContext() const { 277 return *getDialect()->getContext(); 278 } 279 280 Type LLVMTypeConverter::getIndexType() const { 281 return IntegerType::get(&getContext(), getIndexTypeBitwidth()); 282 } 283 284 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const { 285 return options.dataLayout.getPointerSizeInBits(addressSpace); 286 } 287 288 Type LLVMTypeConverter::convertIndexType(IndexType type) const { 289 return getIndexType(); 290 } 291 292 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const { 293 return IntegerType::get(&getContext(), type.getWidth()); 294 } 295 296 Type LLVMTypeConverter::convertFloatType(FloatType type) const { 297 // Valid LLVM float types are used directly. 298 if (LLVM::isCompatibleType(type)) 299 return type; 300 301 // F4, F6, F8 types are converted to integer types with the same bit width. 302 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, 303 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, 304 Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, 305 Float8E8M0FNUType>(type)) 306 return IntegerType::get(&getContext(), type.getWidth()); 307 308 // Other floating-point types: A custom type conversion rule must be 309 // specified by the user. 310 return Type(); 311 } 312 313 // Convert a `ComplexType` to an LLVM type. The result is a complex number 314 // struct with entries for the 315 // 1. real part and for the 316 // 2. imaginary part. 317 Type LLVMTypeConverter::convertComplexType(ComplexType type) const { 318 auto elementType = convertType(type.getElementType()); 319 return LLVM::LLVMStructType::getLiteral(&getContext(), 320 {elementType, elementType}); 321 } 322 323 // Except for signatures, MLIR function types are converted into LLVM 324 // pointer-to-function types. 325 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const { 326 return LLVM::LLVMPointerType::get(type.getContext()); 327 } 328 329 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the 330 /// function arguments. Returns an empty container if none of these attributes 331 /// are found in any of the arguments. 332 static void 333 filterByValRefArgAttrs(FunctionOpInterface funcOp, 334 SmallVectorImpl<std::optional<NamedAttribute>> &result) { 335 assert(result.empty() && "Unexpected non-empty output"); 336 result.resize(funcOp.getNumArguments(), std::nullopt); 337 bool foundByValByRefAttrs = false; 338 for (int argIdx : llvm::seq(funcOp.getNumArguments())) { 339 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) { 340 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() || 341 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) { 342 foundByValByRefAttrs = true; 343 result[argIdx] = namedAttr; 344 break; 345 } 346 } 347 } 348 349 if (!foundByValByRefAttrs) 350 result.clear(); 351 } 352 353 // Function types are converted to LLVM Function types by recursively converting 354 // argument and result types. If MLIR Function has zero results, the LLVM 355 // Function has one VoidType result. If MLIR Function has more than one result, 356 // they are into an LLVM StructType in their order of appearance. 357 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and 358 // `llvm.byref` function arguments which are not LLVM pointers are overridden 359 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already 360 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`. 361 Type LLVMTypeConverter::convertFunctionSignatureImpl( 362 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, 363 LLVMTypeConverter::SignatureConversion &result, 364 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const { 365 // Select the argument converter depending on the calling convention. 366 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; 367 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter 368 : structFuncArgTypeConverter; 369 // Convert argument types one by one and check for errors. 370 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { 371 SmallVector<Type, 8> converted; 372 if (failed(funcArgConverter(*this, type, converted))) 373 return {}; 374 375 // Rewrite converted type of `llvm.byval` or `llvm.byref` function 376 // argument that was not converted to an LLVM pointer types. 377 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() && 378 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) { 379 // If the argument was already converted to an LLVM pointer type, we stop 380 // tracking it as it doesn't need more processing. 381 if (isa<LLVM::LLVMPointerType>(converted[0])) 382 (*byValRefNonPtrAttrs)[idx] = std::nullopt; 383 else 384 converted[0] = LLVM::LLVMPointerType::get(&getContext()); 385 } 386 387 result.addInputs(idx, converted); 388 } 389 390 // If function does not return anything, create the void result type, 391 // if it returns on element, convert it, otherwise pack the result types into 392 // a struct. 393 Type resultType = 394 funcTy.getNumResults() == 0 395 ? LLVM::LLVMVoidType::get(&getContext()) 396 : packFunctionResults(funcTy.getResults(), useBarePtrCallConv); 397 if (!resultType) 398 return {}; 399 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(), 400 isVariadic); 401 } 402 403 Type LLVMTypeConverter::convertFunctionSignature( 404 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, 405 LLVMTypeConverter::SignatureConversion &result) const { 406 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv, 407 result, 408 /*byValRefNonPtrAttrs=*/nullptr); 409 } 410 411 Type LLVMTypeConverter::convertFunctionSignature( 412 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv, 413 LLVMTypeConverter::SignatureConversion &result, 414 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const { 415 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those 416 // that were not converted to LLVM pointer types will be returned for further 417 // processing. 418 filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs); 419 auto funcTy = cast<FunctionType>(funcOp.getFunctionType()); 420 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv, 421 result, &byValRefNonPtrAttrs); 422 } 423 424 /// Converts the function type to a C-compatible format, in particular using 425 /// pointers to memref descriptors for arguments. 426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType> 427 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { 428 SmallVector<Type, 4> inputs; 429 430 Type resultType = type.getNumResults() == 0 431 ? LLVM::LLVMVoidType::get(&getContext()) 432 : packFunctionResults(type.getResults()); 433 if (!resultType) 434 return {}; 435 436 auto ptrType = LLVM::LLVMPointerType::get(type.getContext()); 437 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType); 438 if (structType) { 439 // Struct types cannot be safely returned via C interface. Make this a 440 // pointer argument, instead. 441 inputs.push_back(ptrType); 442 resultType = LLVM::LLVMVoidType::get(&getContext()); 443 } 444 445 for (Type t : type.getInputs()) { 446 auto converted = convertType(t); 447 if (!converted || !LLVM::isCompatibleType(converted)) 448 return {}; 449 if (isa<MemRefType, UnrankedMemRefType>(t)) 450 converted = ptrType; 451 inputs.push_back(converted); 452 } 453 454 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType}; 455 } 456 457 /// Convert a memref type into a list of LLVM IR types that will form the 458 /// memref descriptor. The result contains the following types: 459 /// 1. The pointer to the allocated data buffer, followed by 460 /// 2. The pointer to the aligned data buffer, followed by 461 /// 3. A lowered `index`-type integer containing the distance between the 462 /// beginning of the buffer and the first element to be accessed through the 463 /// view, followed by 464 /// 4. An array containing as many `index`-type integers as the rank of the 465 /// MemRef: the array represents the size, in number of elements, of the memref 466 /// along the given dimension. For constant MemRef dimensions, the 467 /// corresponding size entry is a constant whose runtime value must match the 468 /// static value, followed by 469 /// 5. A second array containing as many `index`-type integers as the rank of 470 /// the MemRef: the second array represents the "stride" (in tensor abstraction 471 /// sense), i.e. the number of consecutive elements of the underlying buffer. 472 /// TODO: add assertions for the static cases. 473 /// 474 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5) 475 /// are expanded into individual index-type elements. 476 /// 477 /// template <typename Elem, typename Index, size_t Rank> 478 /// struct { 479 /// Elem *allocatedPtr; 480 /// Elem *alignedPtr; 481 /// Index offset; 482 /// Index sizes[Rank]; // omitted when rank == 0 483 /// Index strides[Rank]; // omitted when rank == 0 484 /// }; 485 SmallVector<Type, 5> 486 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, 487 bool unpackAggregates) const { 488 if (!type.isStrided()) { 489 emitError( 490 UnknownLoc::get(type.getContext()), 491 "conversion to strided form failed either due to non-strided layout " 492 "maps (which should have been normalized away) or other reasons"); 493 return {}; 494 } 495 496 Type elementType = convertType(type.getElementType()); 497 if (!elementType) 498 return {}; 499 500 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type); 501 if (failed(addressSpace)) { 502 emitError(UnknownLoc::get(type.getContext()), 503 "conversion of memref memory space ") 504 << type.getMemorySpace() 505 << " to integer address space " 506 "failed. Consider adding memory space conversions."; 507 return {}; 508 } 509 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); 510 511 auto indexTy = getIndexType(); 512 513 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy}; 514 auto rank = type.getRank(); 515 if (rank == 0) 516 return results; 517 518 if (unpackAggregates) 519 results.insert(results.end(), 2 * rank, indexTy); 520 else 521 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank)); 522 return results; 523 } 524 525 unsigned 526 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, 527 const DataLayout &layout) const { 528 // Compute the descriptor size given that of its components indicated above. 529 unsigned space = *getMemRefAddressSpace(type); 530 return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + 531 (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType()); 532 } 533 534 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that 535 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`. 536 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const { 537 // When converting a MemRefType to a struct with descriptor fields, do not 538 // unpack the `sizes` and `strides` arrays. 539 SmallVector<Type, 5> types = 540 getMemRefDescriptorFields(type, /*unpackAggregates=*/false); 541 if (types.empty()) 542 return {}; 543 return LLVM::LLVMStructType::getLiteral(&getContext(), types); 544 } 545 546 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types 547 /// that will form the unranked memref descriptor. In particular, the fields 548 /// for an unranked memref descriptor are: 549 /// 1. index-typed rank, the dynamic rank of this MemRef 550 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be 551 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to 552 /// be unranked. 553 SmallVector<Type, 2> 554 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const { 555 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())}; 556 } 557 558 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize( 559 UnrankedMemRefType type, const DataLayout &layout) const { 560 // Compute the descriptor size given that of its components indicated above. 561 unsigned space = *getMemRefAddressSpace(type); 562 return layout.getTypeSize(getIndexType()) + 563 llvm::divideCeil(getPointerBitwidth(space), 8); 564 } 565 566 Type LLVMTypeConverter::convertUnrankedMemRefType( 567 UnrankedMemRefType type) const { 568 if (!convertType(type.getElementType())) 569 return {}; 570 return LLVM::LLVMStructType::getLiteral(&getContext(), 571 getUnrankedMemRefDescriptorFields()); 572 } 573 574 FailureOr<unsigned> 575 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const { 576 if (!type.getMemorySpace()) // Default memory space -> 0. 577 return 0; 578 std::optional<Attribute> converted = 579 convertTypeAttribute(type, type.getMemorySpace()); 580 if (!converted) 581 return failure(); 582 if (!(*converted)) // Conversion to default is 0. 583 return 0; 584 if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) { 585 if (explicitSpace.getType().isIndex() || 586 explicitSpace.getType().isSignlessInteger()) 587 return explicitSpace.getInt(); 588 } 589 return failure(); 590 } 591 592 // Check if a memref type can be converted to a bare pointer. 593 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { 594 if (isa<UnrankedMemRefType>(type)) 595 // Unranked memref is not supported in the bare pointer calling convention. 596 return false; 597 598 // Check that the memref has static shape, strides and offset. Otherwise, it 599 // cannot be lowered to a bare pointer. 600 auto memrefTy = cast<MemRefType>(type); 601 if (!memrefTy.hasStaticShape()) 602 return false; 603 604 int64_t offset = 0; 605 SmallVector<int64_t, 4> strides; 606 if (failed(memrefTy.getStridesAndOffset(strides, offset))) 607 return false; 608 609 for (int64_t stride : strides) 610 if (ShapedType::isDynamic(stride)) 611 return false; 612 613 return !ShapedType::isDynamic(offset); 614 } 615 616 /// Convert a memref type to a bare pointer to the memref element type. 617 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const { 618 if (!canConvertToBarePtr(type)) 619 return {}; 620 Type elementType = convertType(type.getElementType()); 621 if (!elementType) 622 return {}; 623 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type); 624 if (failed(addressSpace)) 625 return {}; 626 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); 627 } 628 629 /// Convert an n-D vector type to an LLVM vector type: 630 /// * 0-D `vector<T>` are converted to vector<1xT> 631 /// * 1-D `vector<axT>` remains as is while, 632 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to 633 /// `!llvm.array<ax...array<jxvector<kxT>>>`. 634 /// As LLVM supports arrays of scalable vectors, this method will also convert 635 /// n-D scalable vectors provided that only the trailing dim is scalable. 636 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const { 637 auto elementType = convertType(type.getElementType()); 638 if (!elementType) 639 return {}; 640 if (type.getShape().empty()) 641 return VectorType::get({1}, elementType); 642 Type vectorType = VectorType::get(type.getShape().back(), elementType, 643 type.getScalableDims().back()); 644 assert(LLVM::isCompatibleVectorType(vectorType) && 645 "expected vector type compatible with the LLVM dialect"); 646 // For n-D vector types for which a _non-trailing_ dim is scalable, 647 // return a failure. Supporting such cases would require LLVM 648 // to support something akin "scalable arrays" of vectors. 649 if (llvm::is_contained(type.getScalableDims().drop_back(), true)) 650 return failure(); 651 auto shape = type.getShape(); 652 for (int i = shape.size() - 2; i >= 0; --i) 653 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); 654 return vectorType; 655 } 656 657 /// Convert a type in the context of the default or bare pointer calling 658 /// convention. Calling convention sensitive types, such as MemRefType and 659 /// UnrankedMemRefType, are converted following the specific rules for the 660 /// calling convention. Calling convention independent types are converted 661 /// following the default LLVM type conversions. 662 Type LLVMTypeConverter::convertCallingConventionType( 663 Type type, bool useBarePtrCallConv) const { 664 if (useBarePtrCallConv) 665 if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) 666 return convertMemRefToBarePtr(memrefTy); 667 668 return convertType(type); 669 } 670 671 /// Promote the bare pointers in 'values' that resulted from memrefs to 672 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion 673 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 674 void LLVMTypeConverter::promoteBarePtrsToDescriptors( 675 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes, 676 SmallVectorImpl<Value> &values) const { 677 assert(stdTypes.size() == values.size() && 678 "The number of types and values doesn't match"); 679 for (unsigned i = 0, end = values.size(); i < end; ++i) 680 if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) 681 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, 682 memrefTy, values[i]); 683 } 684 685 /// Convert a non-empty list of types of values produced by an operation into an 686 /// LLVM-compatible type. In particular, if more than one value is 687 /// produced, create a literal structure with elements that correspond to each 688 /// of the types converted with `convertType`. 689 Type LLVMTypeConverter::packOperationResults(TypeRange types) const { 690 assert(!types.empty() && "expected non-empty list of type"); 691 if (types.size() == 1) 692 return convertType(types[0]); 693 694 SmallVector<Type> resultTypes; 695 resultTypes.reserve(types.size()); 696 for (Type type : types) { 697 Type converted = convertType(type); 698 if (!converted || !LLVM::isCompatibleType(converted)) 699 return {}; 700 resultTypes.push_back(converted); 701 } 702 703 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); 704 } 705 706 /// Convert a non-empty list of types to be returned from a function into an 707 /// LLVM-compatible type. In particular, if more than one value is returned, 708 /// create an LLVM dialect structure type with elements that correspond to each 709 /// of the types converted with `convertCallingConventionType`. 710 Type LLVMTypeConverter::packFunctionResults(TypeRange types, 711 bool useBarePtrCallConv) const { 712 assert(!types.empty() && "expected non-empty list of type"); 713 714 useBarePtrCallConv |= options.useBarePtrCallConv; 715 if (types.size() == 1) 716 return convertCallingConventionType(types.front(), useBarePtrCallConv); 717 718 SmallVector<Type> resultTypes; 719 resultTypes.reserve(types.size()); 720 for (auto t : types) { 721 auto converted = convertCallingConventionType(t, useBarePtrCallConv); 722 if (!converted || !LLVM::isCompatibleType(converted)) 723 return {}; 724 resultTypes.push_back(converted); 725 } 726 727 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); 728 } 729 730 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, 731 OpBuilder &builder) const { 732 // Alloca with proper alignment. We do not expect optimizations of this 733 // alloca op and so we omit allocating at the entry block. 734 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 735 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), 736 builder.getIndexAttr(1)); 737 Value allocated = 738 builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one); 739 // Store into the alloca'ed descriptor. 740 builder.create<LLVM::StoreOp>(loc, operand, allocated); 741 return allocated; 742 } 743 744 SmallVector<Value, 4> 745 LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, 746 ValueRange operands, OpBuilder &builder, 747 bool useBarePtrCallConv) const { 748 SmallVector<Value, 4> promotedOperands; 749 promotedOperands.reserve(operands.size()); 750 useBarePtrCallConv |= options.useBarePtrCallConv; 751 for (auto it : llvm::zip(opOperands, operands)) { 752 auto operand = std::get<0>(it); 753 auto llvmOperand = std::get<1>(it); 754 755 if (useBarePtrCallConv) { 756 // For the bare-ptr calling convention, we only have to extract the 757 // aligned pointer of a memref. 758 if (dyn_cast<MemRefType>(operand.getType())) { 759 MemRefDescriptor desc(llvmOperand); 760 llvmOperand = desc.alignedPtr(builder, loc); 761 } else if (isa<UnrankedMemRefType>(operand.getType())) { 762 llvm_unreachable("Unranked memrefs are not supported"); 763 } 764 } else { 765 if (isa<UnrankedMemRefType>(operand.getType())) { 766 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, 767 promotedOperands); 768 continue; 769 } 770 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { 771 MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, 772 promotedOperands); 773 continue; 774 } 775 } 776 777 promotedOperands.push_back(llvmOperand); 778 } 779 return promotedOperands; 780 } 781 782 /// Callback to convert function argument types. It converts a MemRef function 783 /// argument to a list of non-aggregate types containing descriptor 784 /// information, and an UnrankedmemRef function argument to a list containing 785 /// the rank and a pointer to a descriptor struct. 786 LogicalResult 787 mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, 788 SmallVectorImpl<Type> &result) { 789 if (auto memref = dyn_cast<MemRefType>(type)) { 790 // In signatures, Memref descriptors are expanded into lists of 791 // non-aggregate values. 792 auto converted = 793 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true); 794 if (converted.empty()) 795 return failure(); 796 result.append(converted.begin(), converted.end()); 797 return success(); 798 } 799 if (isa<UnrankedMemRefType>(type)) { 800 auto converted = converter.getUnrankedMemRefDescriptorFields(); 801 if (converted.empty()) 802 return failure(); 803 result.append(converted.begin(), converted.end()); 804 return success(); 805 } 806 auto converted = converter.convertType(type); 807 if (!converted) 808 return failure(); 809 result.push_back(converted); 810 return success(); 811 } 812 813 /// Callback to convert function argument types. It converts MemRef function 814 /// arguments to bare pointers to the MemRef element type. 815 LogicalResult 816 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, 817 SmallVectorImpl<Type> &result) { 818 auto llvmTy = converter.convertCallingConventionType( 819 type, /*useBarePointerCallConv=*/true); 820 if (!llvmTy) 821 return failure(); 822 823 result.push_back(llvmTy); 824 return success(); 825 } 826