1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the translation between an MLIR LLVM dialect module and 10 // the corresponding LLVMIR module. It only handles core LLVM IR operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 15 16 #include "AttrKindDetail.h" 17 #include "DebugTranslation.h" 18 #include "LoopAnnotationTranslation.h" 19 #include "mlir/Dialect/DLTI/DLTI.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" 22 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" 23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 24 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" 25 #include "mlir/IR/AttrTypeSubElements.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/BuiltinTypes.h" 29 #include "mlir/IR/RegionGraphTraits.h" 30 #include "mlir/Support/LLVM.h" 31 #include "mlir/Support/LogicalResult.h" 32 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" 33 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 34 #include "mlir/Transforms/RegionUtils.h" 35 36 #include "llvm/ADT/PostOrderIterator.h" 37 #include "llvm/ADT/SetVector.h" 38 #include "llvm/ADT/TypeSwitch.h" 39 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 40 #include "llvm/IR/BasicBlock.h" 41 #include "llvm/IR/CFG.h" 42 #include "llvm/IR/Constants.h" 43 #include "llvm/IR/DerivedTypes.h" 44 #include "llvm/IR/IRBuilder.h" 45 #include "llvm/IR/InlineAsm.h" 46 #include "llvm/IR/IntrinsicsNVPTX.h" 47 #include "llvm/IR/LLVMContext.h" 48 #include "llvm/IR/MDBuilder.h" 49 #include "llvm/IR/Module.h" 50 #include "llvm/IR/Verifier.h" 51 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 52 #include "llvm/Transforms/Utils/Cloning.h" 53 #include "llvm/Transforms/Utils/ModuleUtils.h" 54 #include <optional> 55 56 using namespace mlir; 57 using namespace mlir::LLVM; 58 using namespace mlir::LLVM::detail; 59 60 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" 61 62 namespace { 63 /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR 64 /// instructions that are created for future reference. 65 /// 66 /// This is intended to be used with the `CollectionScope` RAII object: 67 /// 68 /// llvm::IRBuilder<..., InstructionCapturingInserter> builder; 69 /// { 70 /// InstructionCapturingInserter::CollectionScope scope(builder); 71 /// // Call IRBuilder methods as usual. 72 /// 73 /// // This will return a list of all instructions created by the builder, 74 /// // in order of creation. 75 /// builder.getInserter().getCapturedInstructions(); 76 /// } 77 /// // This will return an empty list. 78 /// builder.getInserter().getCapturedInstructions(); 79 /// 80 /// The capturing functionality is _disabled_ by default for performance 81 /// consideration. It needs to be explicitly enabled, which is achieved by 82 /// creating a `CollectionScope`. 83 class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter { 84 public: 85 /// Constructs the inserter. 86 InstructionCapturingInserter() 87 : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) { 88 if (LLVM_LIKELY(enabled)) 89 capturedInstructions.push_back(instruction); 90 }) {} 91 92 /// Returns the list of LLVM IR instructions captured since the last cleanup. 93 ArrayRef<llvm::Instruction *> getCapturedInstructions() const { 94 return capturedInstructions; 95 } 96 97 /// Clears the list of captured LLVM IR instructions. 98 void clearCapturedInstructions() { capturedInstructions.clear(); } 99 100 /// RAII object enabling the capture of created LLVM IR instructions. 101 class CollectionScope { 102 public: 103 /// Creates the scope for the given inserter. 104 CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing); 105 106 /// Ends the scope. 107 ~CollectionScope(); 108 109 ArrayRef<llvm::Instruction *> getCapturedInstructions() { 110 if (!inserter) 111 return {}; 112 return inserter->getCapturedInstructions(); 113 } 114 115 private: 116 /// Back reference to the inserter. 117 InstructionCapturingInserter *inserter = nullptr; 118 119 /// List of instructions in the inserter prior to this scope. 120 SmallVector<llvm::Instruction *> previouslyCollectedInstructions; 121 122 /// Whether the inserter was enabled prior to this scope. 123 bool wasEnabled; 124 }; 125 126 /// Enable or disable the capturing mechanism. 127 void setEnabled(bool enabled = true) { this->enabled = enabled; } 128 129 private: 130 /// List of captured instructions. 131 SmallVector<llvm::Instruction *> capturedInstructions; 132 133 /// Whether the collection is enabled. 134 bool enabled = false; 135 }; 136 137 using CapturingIRBuilder = 138 llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>; 139 } // namespace 140 141 InstructionCapturingInserter::CollectionScope::CollectionScope( 142 llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) { 143 144 if (!isBuilderCapturing) 145 return; 146 147 auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder); 148 inserter = &capturingIRBuilder.getInserter(); 149 wasEnabled = inserter->enabled; 150 if (wasEnabled) 151 previouslyCollectedInstructions.swap(inserter->capturedInstructions); 152 inserter->setEnabled(true); 153 } 154 155 InstructionCapturingInserter::CollectionScope::~CollectionScope() { 156 if (!inserter) 157 return; 158 159 previouslyCollectedInstructions.swap(inserter->capturedInstructions); 160 // If collection was enabled (likely in another, surrounding scope), keep 161 // the instructions collected in this scope. 162 if (wasEnabled) { 163 llvm::append_range(inserter->capturedInstructions, 164 previouslyCollectedInstructions); 165 } 166 inserter->setEnabled(wasEnabled); 167 } 168 169 /// Translates the given data layout spec attribute to the LLVM IR data layout. 170 /// Only integer, float, pointer and endianness entries are currently supported. 171 static FailureOr<llvm::DataLayout> 172 translateDataLayout(DataLayoutSpecInterface attribute, 173 const DataLayout &dataLayout, 174 std::optional<Location> loc = std::nullopt) { 175 if (!loc) 176 loc = UnknownLoc::get(attribute.getContext()); 177 178 // Translate the endianness attribute. 179 std::string llvmDataLayout; 180 llvm::raw_string_ostream layoutStream(llvmDataLayout); 181 for (DataLayoutEntryInterface entry : attribute.getEntries()) { 182 auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey()); 183 if (!key) 184 continue; 185 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) { 186 auto value = cast<StringAttr>(entry.getValue()); 187 bool isLittleEndian = 188 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle; 189 layoutStream << "-" << (isLittleEndian ? "e" : "E"); 190 layoutStream.flush(); 191 continue; 192 } 193 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) { 194 auto value = cast<IntegerAttr>(entry.getValue()); 195 uint64_t space = value.getValue().getZExtValue(); 196 // Skip the default address space. 197 if (space == 0) 198 continue; 199 layoutStream << "-A" << space; 200 layoutStream.flush(); 201 continue; 202 } 203 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) { 204 auto value = cast<IntegerAttr>(entry.getValue()); 205 uint64_t alignment = value.getValue().getZExtValue(); 206 // Skip the default stack alignment. 207 if (alignment == 0) 208 continue; 209 layoutStream << "-S" << alignment; 210 layoutStream.flush(); 211 continue; 212 } 213 emitError(*loc) << "unsupported data layout key " << key; 214 return failure(); 215 } 216 217 // Go through the list of entries to check which types are explicitly 218 // specified in entries. Where possible, data layout queries are used instead 219 // of directly inspecting the entries. 220 for (DataLayoutEntryInterface entry : attribute.getEntries()) { 221 auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()); 222 if (!type) 223 continue; 224 // Data layout for the index type is irrelevant at this point. 225 if (isa<IndexType>(type)) 226 continue; 227 layoutStream << "-"; 228 LogicalResult result = 229 llvm::TypeSwitch<Type, LogicalResult>(type) 230 .Case<IntegerType, Float16Type, Float32Type, Float64Type, 231 Float80Type, Float128Type>([&](Type type) -> LogicalResult { 232 if (auto intType = dyn_cast<IntegerType>(type)) { 233 if (intType.getSignedness() != IntegerType::Signless) 234 return emitError(*loc) 235 << "unsupported data layout for non-signless integer " 236 << intType; 237 layoutStream << "i"; 238 } else { 239 layoutStream << "f"; 240 } 241 uint64_t size = dataLayout.getTypeSizeInBits(type); 242 uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; 243 uint64_t preferred = 244 dataLayout.getTypePreferredAlignment(type) * 8u; 245 layoutStream << size << ":" << abi; 246 if (abi != preferred) 247 layoutStream << ":" << preferred; 248 return success(); 249 }) 250 .Case([&](LLVMPointerType ptrType) { 251 layoutStream << "p" << ptrType.getAddressSpace() << ":"; 252 uint64_t size = dataLayout.getTypeSizeInBits(type); 253 uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; 254 uint64_t preferred = 255 dataLayout.getTypePreferredAlignment(type) * 8u; 256 layoutStream << size << ":" << abi << ":" << preferred; 257 if (std::optional<uint64_t> index = extractPointerSpecValue( 258 entry.getValue(), PtrDLEntryPos::Index)) 259 layoutStream << ":" << *index; 260 return success(); 261 }) 262 .Default([loc](Type type) { 263 return emitError(*loc) 264 << "unsupported type in data layout: " << type; 265 }); 266 if (failed(result)) 267 return failure(); 268 } 269 layoutStream.flush(); 270 StringRef layoutSpec(llvmDataLayout); 271 if (layoutSpec.starts_with("-")) 272 layoutSpec = layoutSpec.drop_front(); 273 274 return llvm::DataLayout(layoutSpec); 275 } 276 277 /// Builds a constant of a sequential LLVM type `type`, potentially containing 278 /// other sequential types recursively, from the individual constant values 279 /// provided in `constants`. `shape` contains the number of elements in nested 280 /// sequential types. Reports errors at `loc` and returns nullptr on error. 281 static llvm::Constant * 282 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants, 283 ArrayRef<int64_t> shape, llvm::Type *type, 284 Location loc) { 285 if (shape.empty()) { 286 llvm::Constant *result = constants.front(); 287 constants = constants.drop_front(); 288 return result; 289 } 290 291 llvm::Type *elementType; 292 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) { 293 elementType = arrayTy->getElementType(); 294 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) { 295 elementType = vectorTy->getElementType(); 296 } else { 297 emitError(loc) << "expected sequential LLVM types wrapping a scalar"; 298 return nullptr; 299 } 300 301 SmallVector<llvm::Constant *, 8> nested; 302 nested.reserve(shape.front()); 303 for (int64_t i = 0; i < shape.front(); ++i) { 304 nested.push_back(buildSequentialConstant(constants, shape.drop_front(), 305 elementType, loc)); 306 if (!nested.back()) 307 return nullptr; 308 } 309 310 if (shape.size() == 1 && type->isVectorTy()) 311 return llvm::ConstantVector::get(nested); 312 return llvm::ConstantArray::get( 313 llvm::ArrayType::get(elementType, shape.front()), nested); 314 } 315 316 /// Returns the first non-sequential type nested in sequential types. 317 static llvm::Type *getInnermostElementType(llvm::Type *type) { 318 do { 319 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) { 320 type = arrayTy->getElementType(); 321 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) { 322 type = vectorTy->getElementType(); 323 } else { 324 return type; 325 } 326 } while (true); 327 } 328 329 /// Convert a dense elements attribute to an LLVM IR constant using its raw data 330 /// storage if possible. This supports elements attributes of tensor or vector 331 /// type and avoids constructing separate objects for individual values of the 332 /// innermost dimension. Constants for other dimensions are still constructed 333 /// recursively. Returns null if constructing from raw data is not supported for 334 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports 335 /// other errors at `loc`. 336 static llvm::Constant * 337 convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, 338 llvm::Type *llvmType, 339 const ModuleTranslation &moduleTranslation) { 340 if (!denseElementsAttr) 341 return nullptr; 342 343 llvm::Type *innermostLLVMType = getInnermostElementType(llvmType); 344 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType)) 345 return nullptr; 346 347 ShapedType type = denseElementsAttr.getType(); 348 if (type.getNumElements() == 0) 349 return nullptr; 350 351 // Check that the raw data size matches what is expected for the scalar size. 352 // TODO: in theory, we could repack the data here to keep constructing from 353 // raw data. 354 // TODO: we may also need to consider endianness when cross-compiling to an 355 // architecture where it is different. 356 int64_t elementByteSize = denseElementsAttr.getRawData().size() / 357 denseElementsAttr.getNumElements(); 358 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) 359 return nullptr; 360 361 // Compute the shape of all dimensions but the innermost. Note that the 362 // innermost dimension may be that of the vector element type. 363 bool hasVectorElementType = isa<VectorType>(type.getElementType()); 364 int64_t numAggregates = 365 denseElementsAttr.getNumElements() / 366 (hasVectorElementType ? 1 367 : denseElementsAttr.getType().getShape().back()); 368 ArrayRef<int64_t> outerShape = type.getShape(); 369 if (!hasVectorElementType) 370 outerShape = outerShape.drop_back(); 371 372 // Handle the case of vector splat, LLVM has special support for it. 373 if (denseElementsAttr.isSplat() && 374 (isa<VectorType>(type) || hasVectorElementType)) { 375 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( 376 innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc, 377 moduleTranslation); 378 llvm::Constant *splatVector = 379 llvm::ConstantDataVector::getSplat(0, splatValue); 380 SmallVector<llvm::Constant *> constants(numAggregates, splatVector); 381 ArrayRef<llvm::Constant *> constantsRef = constants; 382 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc); 383 } 384 if (denseElementsAttr.isSplat()) 385 return nullptr; 386 387 // In case of non-splat, create a constructor for the innermost constant from 388 // a piece of raw data. 389 std::function<llvm::Constant *(StringRef)> buildCstData; 390 if (isa<TensorType>(type)) { 391 auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); 392 if (vectorElementType && vectorElementType.getRank() == 1) { 393 buildCstData = [&](StringRef data) { 394 return llvm::ConstantDataVector::getRaw( 395 data, vectorElementType.getShape().back(), innermostLLVMType); 396 }; 397 } else if (!vectorElementType) { 398 buildCstData = [&](StringRef data) { 399 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), 400 innermostLLVMType); 401 }; 402 } 403 } else if (isa<VectorType>(type)) { 404 buildCstData = [&](StringRef data) { 405 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), 406 innermostLLVMType); 407 }; 408 } 409 if (!buildCstData) 410 return nullptr; 411 412 // Create innermost constants and defer to the default constant creation 413 // mechanism for other dimensions. 414 SmallVector<llvm::Constant *> constants; 415 int64_t aggregateSize = denseElementsAttr.getType().getShape().back() * 416 (innermostLLVMType->getScalarSizeInBits() / 8); 417 constants.reserve(numAggregates); 418 for (unsigned i = 0; i < numAggregates; ++i) { 419 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize, 420 aggregateSize); 421 constants.push_back(buildCstData(data)); 422 } 423 424 ArrayRef<llvm::Constant *> constantsRef = constants; 425 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc); 426 } 427 428 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. 429 /// This currently supports integer, floating point, splat and dense element 430 /// attributes and combinations thereof. Also, an array attribute with two 431 /// elements is supported to represent a complex constant. In case of error, 432 /// report it to `loc` and return nullptr. 433 llvm::Constant *mlir::LLVM::detail::getLLVMConstant( 434 llvm::Type *llvmType, Attribute attr, Location loc, 435 const ModuleTranslation &moduleTranslation) { 436 if (!attr) 437 return llvm::UndefValue::get(llvmType); 438 if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { 439 auto arrayAttr = dyn_cast<ArrayAttr>(attr); 440 if (!arrayAttr || arrayAttr.size() != 2) { 441 emitError(loc, "expected struct type to be a complex number"); 442 return nullptr; 443 } 444 llvm::Type *elementType = structType->getElementType(0); 445 llvm::Constant *real = 446 getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); 447 if (!real) 448 return nullptr; 449 llvm::Constant *imag = 450 getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); 451 if (!imag) 452 return nullptr; 453 return llvm::ConstantStruct::get(structType, {real, imag}); 454 } 455 // For integer types, we allow a mismatch in sizes as the index type in 456 // MLIR might have a different size than the index type in the LLVM module. 457 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) 458 return llvm::ConstantInt::get( 459 llvmType, 460 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); 461 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { 462 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); 463 // Special case for 8-bit floats, which are represented by integers due to 464 // the lack of native fp8 types in LLVM at the moment. Additionally, handle 465 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats 466 // to i16. 467 unsigned floatWidth = APFloat::getSizeInBits(sem); 468 if (llvmType->isIntegerTy(floatWidth)) 469 return llvm::ConstantInt::get(llvmType, 470 floatAttr.getValue().bitcastToAPInt()); 471 if (llvmType != 472 llvm::Type::getFloatingPointTy(llvmType->getContext(), 473 floatAttr.getValue().getSemantics())) { 474 emitError(loc, "FloatAttr does not match expected type of the constant"); 475 return nullptr; 476 } 477 return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); 478 } 479 if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr)) 480 return llvm::ConstantExpr::getBitCast( 481 moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType); 482 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) { 483 llvm::Type *elementType; 484 uint64_t numElements; 485 bool isScalable = false; 486 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) { 487 elementType = arrayTy->getElementType(); 488 numElements = arrayTy->getNumElements(); 489 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) { 490 elementType = fVectorTy->getElementType(); 491 numElements = fVectorTy->getNumElements(); 492 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) { 493 elementType = sVectorTy->getElementType(); 494 numElements = sVectorTy->getMinNumElements(); 495 isScalable = true; 496 } else { 497 llvm_unreachable("unrecognized constant vector type"); 498 } 499 // Splat value is a scalar. Extract it only if the element type is not 500 // another sequence type. The recursion terminates because each step removes 501 // one outer sequential type. 502 bool elementTypeSequential = 503 isa<llvm::ArrayType, llvm::VectorType>(elementType); 504 llvm::Constant *child = getLLVMConstant( 505 elementType, 506 elementTypeSequential ? splatAttr 507 : splatAttr.getSplatValue<Attribute>(), 508 loc, moduleTranslation); 509 if (!child) 510 return nullptr; 511 if (llvmType->isVectorTy()) 512 return llvm::ConstantVector::getSplat( 513 llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child); 514 if (llvmType->isArrayTy()) { 515 auto *arrayType = llvm::ArrayType::get(elementType, numElements); 516 SmallVector<llvm::Constant *, 8> constants(numElements, child); 517 return llvm::ConstantArray::get(arrayType, constants); 518 } 519 } 520 521 // Try using raw elements data if possible. 522 if (llvm::Constant *result = 523 convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr), 524 llvmType, moduleTranslation)) { 525 return result; 526 } 527 528 // Fall back to element-by-element construction otherwise. 529 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) { 530 assert(elementsAttr.getShapedType().hasStaticShape()); 531 assert(!elementsAttr.getShapedType().getShape().empty() && 532 "unexpected empty elements attribute shape"); 533 534 SmallVector<llvm::Constant *, 8> constants; 535 constants.reserve(elementsAttr.getNumElements()); 536 llvm::Type *innermostType = getInnermostElementType(llvmType); 537 for (auto n : elementsAttr.getValues<Attribute>()) { 538 constants.push_back( 539 getLLVMConstant(innermostType, n, loc, moduleTranslation)); 540 if (!constants.back()) 541 return nullptr; 542 } 543 ArrayRef<llvm::Constant *> constantsRef = constants; 544 llvm::Constant *result = buildSequentialConstant( 545 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc); 546 assert(constantsRef.empty() && "did not consume all elemental constants"); 547 return result; 548 } 549 550 if (auto stringAttr = dyn_cast<StringAttr>(attr)) { 551 return llvm::ConstantDataArray::get( 552 moduleTranslation.getLLVMContext(), 553 ArrayRef<char>{stringAttr.getValue().data(), 554 stringAttr.getValue().size()}); 555 } 556 emitError(loc, "unsupported constant value"); 557 return nullptr; 558 } 559 560 ModuleTranslation::ModuleTranslation(Operation *module, 561 std::unique_ptr<llvm::Module> llvmModule) 562 : mlirModule(module), llvmModule(std::move(llvmModule)), 563 debugTranslation( 564 std::make_unique<DebugTranslation>(module, *this->llvmModule)), 565 loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>( 566 *this, *this->llvmModule)), 567 typeTranslator(this->llvmModule->getContext()), 568 iface(module->getContext()) { 569 assert(satisfiesLLVMModule(mlirModule) && 570 "mlirModule should honor LLVM's module semantics."); 571 } 572 573 ModuleTranslation::~ModuleTranslation() { 574 if (ompBuilder) 575 ompBuilder->finalize(); 576 } 577 578 void ModuleTranslation::forgetMapping(Region ®ion) { 579 SmallVector<Region *> toProcess; 580 toProcess.push_back(®ion); 581 while (!toProcess.empty()) { 582 Region *current = toProcess.pop_back_val(); 583 for (Block &block : *current) { 584 blockMapping.erase(&block); 585 for (Value arg : block.getArguments()) 586 valueMapping.erase(arg); 587 for (Operation &op : block) { 588 for (Value value : op.getResults()) 589 valueMapping.erase(value); 590 if (op.hasSuccessors()) 591 branchMapping.erase(&op); 592 if (isa<LLVM::GlobalOp>(op)) 593 globalsMapping.erase(&op); 594 llvm::append_range( 595 toProcess, 596 llvm::map_range(op.getRegions(), [](Region &r) { return &r; })); 597 } 598 } 599 } 600 } 601 602 /// Get the SSA value passed to the current block from the terminator operation 603 /// of its predecessor. 604 static Value getPHISourceValue(Block *current, Block *pred, 605 unsigned numArguments, unsigned index) { 606 Operation &terminator = *pred->getTerminator(); 607 if (isa<LLVM::BrOp>(terminator)) 608 return terminator.getOperand(index); 609 610 #ifndef NDEBUG 611 llvm::SmallPtrSet<Block *, 4> seenSuccessors; 612 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) { 613 Block *successor = terminator.getSuccessor(i); 614 auto branch = cast<BranchOpInterface>(terminator); 615 SuccessorOperands successorOperands = branch.getSuccessorOperands(i); 616 assert( 617 (!seenSuccessors.contains(successor) || successorOperands.empty()) && 618 "successors with arguments in LLVM branches must be different blocks"); 619 seenSuccessors.insert(successor); 620 } 621 #endif 622 623 // For instructions that branch based on a condition value, we need to take 624 // the operands for the branch that was taken. 625 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) { 626 // For conditional branches, we take the operands from either the "true" or 627 // the "false" branch. 628 return condBranchOp.getSuccessor(0) == current 629 ? condBranchOp.getTrueDestOperands()[index] 630 : condBranchOp.getFalseDestOperands()[index]; 631 } 632 633 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) { 634 // For switches, we take the operands from either the default case, or from 635 // the case branch that was taken. 636 if (switchOp.getDefaultDestination() == current) 637 return switchOp.getDefaultOperands()[index]; 638 for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations())) 639 if (i.value() == current) 640 return switchOp.getCaseOperands(i.index())[index]; 641 } 642 643 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) { 644 return invokeOp.getNormalDest() == current 645 ? invokeOp.getNormalDestOperands()[index] 646 : invokeOp.getUnwindDestOperands()[index]; 647 } 648 649 llvm_unreachable( 650 "only branch, switch or invoke operations can be terminators " 651 "of a block that has successors"); 652 } 653 654 /// Connect the PHI nodes to the results of preceding blocks. 655 void mlir::LLVM::detail::connectPHINodes(Region ®ion, 656 const ModuleTranslation &state) { 657 // Skip the first block, it cannot be branched to and its arguments correspond 658 // to the arguments of the LLVM function. 659 for (Block &bb : llvm::drop_begin(region)) { 660 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb); 661 auto phis = llvmBB->phis(); 662 auto numArguments = bb.getNumArguments(); 663 assert(numArguments == std::distance(phis.begin(), phis.end())); 664 for (auto [index, phiNode] : llvm::enumerate(phis)) { 665 for (auto *pred : bb.getPredecessors()) { 666 // Find the LLVM IR block that contains the converted terminator 667 // instruction and use it in the PHI node. Note that this block is not 668 // necessarily the same as state.lookupBlock(pred), some operations 669 // (in particular, OpenMP operations using OpenMPIRBuilder) may have 670 // split the blocks. 671 llvm::Instruction *terminator = 672 state.lookupBranch(pred->getTerminator()); 673 assert(terminator && "missing the mapping for a terminator"); 674 phiNode.addIncoming(state.lookupValue(getPHISourceValue( 675 &bb, pred, numArguments, index)), 676 terminator->getParent()); 677 } 678 } 679 } 680 } 681 682 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( 683 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, 684 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { 685 llvm::Module *module = builder.GetInsertBlock()->getModule(); 686 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys); 687 return builder.CreateCall(fn, args); 688 } 689 690 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( 691 llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation, 692 Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults, 693 ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands, 694 ArrayRef<unsigned> immArgPositions, 695 ArrayRef<StringLiteral> immArgAttrNames) { 696 assert(immArgPositions.size() == immArgAttrNames.size() && 697 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal " 698 "length"); 699 700 // Map operands and attributes to LLVM values. 701 auto operands = moduleTranslation.lookupValues(intrOp->getOperands()); 702 SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size()); 703 for (auto [immArgPos, immArgName] : 704 llvm::zip(immArgPositions, immArgAttrNames)) { 705 auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName)); 706 assert(attr.getType().isIntOrFloat() && "expected int or float immarg"); 707 auto *type = moduleTranslation.convertType(attr.getType()); 708 args[immArgPos] = LLVM::detail::getLLVMConstant( 709 type, attr, intrOp->getLoc(), moduleTranslation); 710 } 711 unsigned opArg = 0; 712 for (auto &arg : args) { 713 if (!arg) 714 arg = operands[opArg++]; 715 } 716 717 // Resolve overloaded intrinsic declaration. 718 SmallVector<llvm::Type *> overloadedTypes; 719 for (unsigned overloadedResultIdx : overloadedResults) { 720 if (numResults > 1) { 721 // More than one result is mapped to an LLVM struct. 722 overloadedTypes.push_back(moduleTranslation.convertType( 723 llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType()) 724 .getBody()[overloadedResultIdx])); 725 } else { 726 overloadedTypes.push_back( 727 moduleTranslation.convertType(intrOp->getResult(0).getType())); 728 } 729 } 730 for (unsigned overloadedOperandIdx : overloadedOperands) 731 overloadedTypes.push_back(args[overloadedOperandIdx]->getType()); 732 llvm::Module *module = builder.GetInsertBlock()->getModule(); 733 llvm::Function *llvmIntr = 734 llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes); 735 736 return builder.CreateCall(llvmIntr, args); 737 } 738 739 /// Given a single MLIR operation, create the corresponding LLVM IR operation 740 /// using the `builder`. 741 LogicalResult ModuleTranslation::convertOperation(Operation &op, 742 llvm::IRBuilderBase &builder, 743 bool recordInsertions) { 744 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op); 745 if (!opIface) 746 return op.emitError("cannot be converted to LLVM IR: missing " 747 "`LLVMTranslationDialectInterface` registration for " 748 "dialect for op: ") 749 << op.getName(); 750 751 InstructionCapturingInserter::CollectionScope scope(builder, 752 recordInsertions); 753 if (failed(opIface->convertOperation(&op, builder, *this))) 754 return op.emitError("LLVM Translation failed for operation: ") 755 << op.getName(); 756 757 return convertDialectAttributes(&op, scope.getCapturedInstructions()); 758 } 759 760 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes 761 /// to define values corresponding to the MLIR block arguments. These nodes 762 /// are not connected to the source basic blocks, which may not exist yet. Uses 763 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have 764 /// been created for `bb` and included in the block mapping. Inserts new 765 /// instructions at the end of the block and leaves `builder` in a state 766 /// suitable for further insertion into the end of the block. 767 LogicalResult ModuleTranslation::convertBlockImpl(Block &bb, 768 bool ignoreArguments, 769 llvm::IRBuilderBase &builder, 770 bool recordInsertions) { 771 builder.SetInsertPoint(lookupBlock(&bb)); 772 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); 773 774 // Before traversing operations, make block arguments available through 775 // value remapping and PHI nodes, but do not add incoming edges for the PHI 776 // nodes just yet: those values may be defined by this or following blocks. 777 // This step is omitted if "ignoreArguments" is set. The arguments of the 778 // first block have been already made available through the remapping of 779 // LLVM function arguments. 780 if (!ignoreArguments) { 781 auto predecessors = bb.getPredecessors(); 782 unsigned numPredecessors = 783 std::distance(predecessors.begin(), predecessors.end()); 784 for (auto arg : bb.getArguments()) { 785 auto wrappedType = arg.getType(); 786 if (!isCompatibleType(wrappedType)) 787 return emitError(bb.front().getLoc(), 788 "block argument does not have an LLVM type"); 789 llvm::Type *type = convertType(wrappedType); 790 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); 791 mapValue(arg, phi); 792 } 793 } 794 795 // Traverse operations. 796 for (auto &op : bb) { 797 // Set the current debug location within the builder. 798 builder.SetCurrentDebugLocation( 799 debugTranslation->translateLoc(op.getLoc(), subprogram)); 800 801 if (failed(convertOperation(op, builder, recordInsertions))) 802 return failure(); 803 804 // Set the branch weight metadata on the translated instruction. 805 if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) 806 setBranchWeightsMetadata(iface); 807 } 808 809 return success(); 810 } 811 812 /// A helper method to get the single Block in an operation honoring LLVM's 813 /// module requirements. 814 static Block &getModuleBody(Operation *module) { 815 return module->getRegion(0).front(); 816 } 817 818 /// A helper method to decide if a constant must not be set as a global variable 819 /// initializer. For an external linkage variable, the variable with an 820 /// initializer is considered externally visible and defined in this module, the 821 /// variable without an initializer is externally available and is defined 822 /// elsewhere. 823 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, 824 llvm::Constant *cst) { 825 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) || 826 linkage == llvm::GlobalVariable::ExternalWeakLinkage; 827 } 828 829 /// Sets the runtime preemption specifier of `gv` to dso_local if 830 /// `dsoLocalRequested` is true, otherwise it is left unchanged. 831 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, 832 llvm::GlobalValue *gv) { 833 if (dsoLocalRequested) 834 gv->setDSOLocal(true); 835 } 836 837 /// Create named global variables that correspond to llvm.mlir.global 838 /// definitions. Convert llvm.global_ctors and global_dtors ops. 839 LogicalResult ModuleTranslation::convertGlobals() { 840 // Mapping from compile unit to its respective set of global variables. 841 DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars; 842 843 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 844 llvm::Type *type = convertType(op.getType()); 845 llvm::Constant *cst = nullptr; 846 if (op.getValueOrNull()) { 847 // String attributes are treated separately because they cannot appear as 848 // in-function constants and are thus not supported by getLLVMConstant. 849 if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) { 850 cst = llvm::ConstantDataArray::getString( 851 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); 852 type = cst->getType(); 853 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(), 854 *this))) { 855 return failure(); 856 } 857 } 858 859 auto linkage = convertLinkageToLLVM(op.getLinkage()); 860 auto addrSpace = op.getAddrSpace(); 861 862 // LLVM IR requires constant with linkage other than external or weak 863 // external to have initializers. If MLIR does not provide an initializer, 864 // default to undef. 865 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst); 866 if (!dropInitializer && !cst) 867 cst = llvm::UndefValue::get(type); 868 else if (dropInitializer && cst) 869 cst = nullptr; 870 871 auto *var = new llvm::GlobalVariable( 872 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), 873 /*InsertBefore=*/nullptr, 874 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel 875 : llvm::GlobalValue::NotThreadLocal, 876 addrSpace); 877 878 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) { 879 auto selectorOp = cast<ComdatSelectorOp>( 880 SymbolTable::lookupNearestSymbolFrom(op, *comdat)); 881 var->setComdat(comdatMapping.lookup(selectorOp)); 882 } 883 884 if (op.getUnnamedAddr().has_value()) 885 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); 886 887 if (op.getSection().has_value()) 888 var->setSection(*op.getSection()); 889 890 addRuntimePreemptionSpecifier(op.getDsoLocal(), var); 891 892 std::optional<uint64_t> alignment = op.getAlignment(); 893 if (alignment.has_value()) 894 var->setAlignment(llvm::MaybeAlign(alignment.value())); 895 896 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); 897 898 globalsMapping.try_emplace(op, var); 899 900 // Add debug information if present. 901 if (op.getDbgExpr()) { 902 llvm::DIGlobalVariableExpression *diGlobalExpr = 903 debugTranslation->translateGlobalVariableExpression(op.getDbgExpr()); 904 llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable(); 905 var->addDebugInfo(diGlobalExpr); 906 907 // Get the compile unit (scope) of the the global variable. 908 if (llvm::DICompileUnit *compileUnit = 909 dyn_cast_if_present<llvm::DICompileUnit>( 910 diGlobalVar->getScope())) { 911 // Update the compile unit with this incoming global variable expression 912 // during the finalizing step later. 913 allGVars[compileUnit].push_back(diGlobalExpr); 914 } 915 } 916 } 917 918 // Convert global variable bodies. This is done after all global variables 919 // have been created in LLVM IR because a global body may refer to another 920 // global or itself. So all global variables need to be mapped first. 921 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 922 if (Block *initializer = op.getInitializerBlock()) { 923 llvm::IRBuilder<> builder(llvmModule->getContext()); 924 for (auto &op : initializer->without_terminator()) { 925 if (failed(convertOperation(op, builder)) || 926 !isa<llvm::Constant>(lookupValue(op.getResult(0)))) 927 return emitError(op.getLoc(), "unemittable constant value"); 928 } 929 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); 930 llvm::Constant *cst = 931 cast<llvm::Constant>(lookupValue(ret.getOperand(0))); 932 auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op)); 933 if (!shouldDropGlobalInitializer(global->getLinkage(), cst)) 934 global->setInitializer(cst); 935 } 936 } 937 938 // Convert llvm.mlir.global_ctors and dtors. 939 for (Operation &op : getModuleBody(mlirModule)) { 940 auto ctorOp = dyn_cast<GlobalCtorsOp>(op); 941 auto dtorOp = dyn_cast<GlobalDtorsOp>(op); 942 if (!ctorOp && !dtorOp) 943 continue; 944 auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities()) 945 : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities()); 946 auto appendGlobalFn = 947 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; 948 for (auto symbolAndPriority : range) { 949 llvm::Function *f = lookupFunction( 950 cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue()); 951 appendGlobalFn(*llvmModule, f, 952 cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(), 953 /*Data=*/nullptr); 954 } 955 } 956 957 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) 958 if (failed(convertDialectAttributes(op, {}))) 959 return failure(); 960 961 // Finally, update the compile units their respective sets of global variables 962 // created earlier. 963 for (const auto &[compileUnit, globals] : allGVars) { 964 compileUnit->replaceGlobalVariables( 965 llvm::MDTuple::get(getLLVMContext(), globals)); 966 } 967 968 return success(); 969 } 970 971 /// Attempts to add an attribute identified by `key`, optionally with the given 972 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the 973 /// attribute has a kind known to LLVM IR, create the attribute of this kind, 974 /// otherwise keep it as a string attribute. Performs additional checks for 975 /// attributes known to have or not have a value in order to avoid assertions 976 /// inside LLVM upon construction. 977 static LogicalResult checkedAddLLVMFnAttribute(Location loc, 978 llvm::Function *llvmFunc, 979 StringRef key, 980 StringRef value = StringRef()) { 981 auto kind = llvm::Attribute::getAttrKindFromName(key); 982 if (kind == llvm::Attribute::None) { 983 llvmFunc->addFnAttr(key, value); 984 return success(); 985 } 986 987 if (llvm::Attribute::isIntAttrKind(kind)) { 988 if (value.empty()) 989 return emitError(loc) << "LLVM attribute '" << key << "' expects a value"; 990 991 int64_t result; 992 if (!value.getAsInteger(/*Radix=*/0, result)) 993 llvmFunc->addFnAttr( 994 llvm::Attribute::get(llvmFunc->getContext(), kind, result)); 995 else 996 llvmFunc->addFnAttr(key, value); 997 return success(); 998 } 999 1000 if (!value.empty()) 1001 return emitError(loc) << "LLVM attribute '" << key 1002 << "' does not expect a value, found '" << value 1003 << "'"; 1004 1005 llvmFunc->addFnAttr(kind); 1006 return success(); 1007 } 1008 1009 /// Attaches the attributes listed in the given array attribute to `llvmFunc`. 1010 /// Reports error to `loc` if any and returns immediately. Expects `attributes` 1011 /// to be an array attribute containing either string attributes, treated as 1012 /// value-less LLVM attributes, or array attributes containing two string 1013 /// attributes, with the first string being the name of the corresponding LLVM 1014 /// attribute and the second string beings its value. Note that even integer 1015 /// attributes are expected to have their values expressed as strings. 1016 static LogicalResult 1017 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes, 1018 llvm::Function *llvmFunc) { 1019 if (!attributes) 1020 return success(); 1021 1022 for (Attribute attr : *attributes) { 1023 if (auto stringAttr = dyn_cast<StringAttr>(attr)) { 1024 if (failed( 1025 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) 1026 return failure(); 1027 continue; 1028 } 1029 1030 auto arrayAttr = dyn_cast<ArrayAttr>(attr); 1031 if (!arrayAttr || arrayAttr.size() != 2) 1032 return emitError(loc) 1033 << "expected 'passthrough' to contain string or array attributes"; 1034 1035 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); 1036 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); 1037 if (!keyAttr || !valueAttr) 1038 return emitError(loc) 1039 << "expected arrays within 'passthrough' to contain two strings"; 1040 1041 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), 1042 valueAttr.getValue()))) 1043 return failure(); 1044 } 1045 return success(); 1046 } 1047 1048 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { 1049 // Clear the block, branch value mappings, they are only relevant within one 1050 // function. 1051 blockMapping.clear(); 1052 valueMapping.clear(); 1053 branchMapping.clear(); 1054 llvm::Function *llvmFunc = lookupFunction(func.getName()); 1055 1056 // Translate the debug information for this function. 1057 debugTranslation->translate(func, *llvmFunc); 1058 1059 // Add function arguments to the value remapping table. 1060 for (auto [mlirArg, llvmArg] : 1061 llvm::zip(func.getArguments(), llvmFunc->args())) 1062 mapValue(mlirArg, &llvmArg); 1063 1064 // Check the personality and set it. 1065 if (func.getPersonality()) { 1066 llvm::Type *ty = llvm::PointerType::getUnqual(llvmFunc->getContext()); 1067 if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(), 1068 func.getLoc(), *this)) 1069 llvmFunc->setPersonalityFn(pfunc); 1070 } 1071 1072 if (std::optional<StringRef> section = func.getSection()) 1073 llvmFunc->setSection(*section); 1074 1075 if (func.getArmStreaming()) 1076 llvmFunc->addFnAttr("aarch64_pstate_sm_enabled"); 1077 else if (func.getArmLocallyStreaming()) 1078 llvmFunc->addFnAttr("aarch64_pstate_sm_body"); 1079 else if (func.getArmStreamingCompatible()) 1080 llvmFunc->addFnAttr("aarch64_pstate_sm_compatible"); 1081 1082 if (func.getArmNewZa()) 1083 llvmFunc->addFnAttr("aarch64_pstate_za_new"); 1084 1085 if (auto targetFeatures = func.getTargetFeatures()) 1086 llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString()); 1087 1088 if (auto attr = func.getVscaleRange()) 1089 llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs( 1090 getLLVMContext(), attr->getMinRange().getInt(), 1091 attr->getMaxRange().getInt())); 1092 1093 // Add function attribute frame-pointer, if found. 1094 if (FramePointerKindAttr attr = func.getFramePointerAttr()) 1095 llvmFunc->addFnAttr("frame-pointer", 1096 LLVM::framePointerKind::stringifyFramePointerKind( 1097 (attr.getFramePointerKind()))); 1098 1099 // First, create all blocks so we can jump to them. 1100 llvm::LLVMContext &llvmContext = llvmFunc->getContext(); 1101 for (auto &bb : func) { 1102 auto *llvmBB = llvm::BasicBlock::Create(llvmContext); 1103 llvmBB->insertInto(llvmFunc); 1104 mapBlock(&bb, llvmBB); 1105 } 1106 1107 // Then, convert blocks one by one in topological order to ensure defs are 1108 // converted before uses. 1109 auto blocks = getTopologicallySortedBlocks(func.getBody()); 1110 for (Block *bb : blocks) { 1111 CapturingIRBuilder builder(llvmContext); 1112 if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder, 1113 /*recordInsertions=*/true))) 1114 return failure(); 1115 } 1116 1117 // After all blocks have been traversed and values mapped, connect the PHI 1118 // nodes to the results of preceding blocks. 1119 detail::connectPHINodes(func.getBody(), *this); 1120 1121 // Finally, convert dialect attributes attached to the function. 1122 return convertDialectAttributes(func, {}); 1123 } 1124 1125 LogicalResult ModuleTranslation::convertDialectAttributes( 1126 Operation *op, ArrayRef<llvm::Instruction *> instructions) { 1127 for (NamedAttribute attribute : op->getDialectAttrs()) 1128 if (failed(iface.amendOperation(op, instructions, attribute, *this))) 1129 return failure(); 1130 return success(); 1131 } 1132 1133 /// Converts the function attributes from LLVMFuncOp and attaches them to the 1134 /// llvm::Function. 1135 static void convertFunctionAttributes(LLVMFuncOp func, 1136 llvm::Function *llvmFunc) { 1137 if (!func.getMemory()) 1138 return; 1139 1140 MemoryEffectsAttr memEffects = func.getMemoryAttr(); 1141 1142 // Add memory effects incrementally. 1143 llvm::MemoryEffects newMemEffects = 1144 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, 1145 convertModRefInfoToLLVM(memEffects.getArgMem())); 1146 newMemEffects |= llvm::MemoryEffects( 1147 llvm::MemoryEffects::Location::InaccessibleMem, 1148 convertModRefInfoToLLVM(memEffects.getInaccessibleMem())); 1149 newMemEffects |= 1150 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, 1151 convertModRefInfoToLLVM(memEffects.getOther())); 1152 llvmFunc->setMemoryEffects(newMemEffects); 1153 } 1154 1155 llvm::AttrBuilder 1156 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) { 1157 llvm::AttrBuilder attrBuilder(llvmModule->getContext()); 1158 1159 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { 1160 Attribute attr = paramAttrs.get(mlirName); 1161 // Skip attributes that are not present. 1162 if (!attr) 1163 continue; 1164 1165 // NOTE: C++17 does not support capturing structured bindings. 1166 llvm::Attribute::AttrKind llvmKindCap = llvmKind; 1167 1168 llvm::TypeSwitch<Attribute>(attr) 1169 .Case<TypeAttr>([&](auto typeAttr) { 1170 attrBuilder.addTypeAttr(llvmKindCap, 1171 convertType(typeAttr.getValue())); 1172 }) 1173 .Case<IntegerAttr>([&](auto intAttr) { 1174 attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt()); 1175 }) 1176 .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); }); 1177 } 1178 1179 return attrBuilder; 1180 } 1181 1182 LogicalResult ModuleTranslation::convertFunctionSignatures() { 1183 // Declare all functions first because there may be function calls that form a 1184 // call graph with cycles, or global initializers that reference functions. 1185 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1186 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( 1187 function.getName(), 1188 cast<llvm::FunctionType>(convertType(function.getFunctionType()))); 1189 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee()); 1190 llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage())); 1191 llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv())); 1192 mapFunction(function.getName(), llvmFunc); 1193 addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc); 1194 1195 // Convert function attributes. 1196 convertFunctionAttributes(function, llvmFunc); 1197 1198 // Convert function_entry_count attribute to metadata. 1199 if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount()) 1200 llvmFunc->setEntryCount(entryCount.value()); 1201 1202 // Convert result attributes. 1203 if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { 1204 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]); 1205 llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs)); 1206 } 1207 1208 // Convert argument attributes. 1209 for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) { 1210 if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) { 1211 llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs); 1212 llvmArg.addAttrs(attrBuilder); 1213 } 1214 } 1215 1216 // Forward the pass-through attributes to LLVM. 1217 if (failed(forwardPassthroughAttributes( 1218 function.getLoc(), function.getPassthrough(), llvmFunc))) 1219 return failure(); 1220 1221 // Convert visibility attribute. 1222 llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_())); 1223 1224 // Convert the comdat attribute. 1225 if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) { 1226 auto selectorOp = cast<ComdatSelectorOp>( 1227 SymbolTable::lookupNearestSymbolFrom(function, *comdat)); 1228 llvmFunc->setComdat(comdatMapping.lookup(selectorOp)); 1229 } 1230 1231 if (auto gc = function.getGarbageCollector()) 1232 llvmFunc->setGC(gc->str()); 1233 1234 if (auto unnamedAddr = function.getUnnamedAddr()) 1235 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr)); 1236 1237 if (auto alignment = function.getAlignment()) 1238 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment)); 1239 } 1240 1241 return success(); 1242 } 1243 1244 LogicalResult ModuleTranslation::convertFunctions() { 1245 // Convert functions. 1246 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1247 // Do not convert external functions, but do process dialect attributes 1248 // attached to them. 1249 if (function.isExternal()) { 1250 if (failed(convertDialectAttributes(function, {}))) 1251 return failure(); 1252 continue; 1253 } 1254 1255 if (failed(convertOneFunction(function))) 1256 return failure(); 1257 } 1258 1259 return success(); 1260 } 1261 1262 LogicalResult ModuleTranslation::convertComdats() { 1263 for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) { 1264 for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) { 1265 llvm::Module *module = getLLVMModule(); 1266 if (module->getComdatSymbolTable().contains(selectorOp.getSymName())) 1267 return emitError(selectorOp.getLoc()) 1268 << "comdat selection symbols must be unique even in different " 1269 "comdat regions"; 1270 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName()); 1271 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat())); 1272 comdatMapping.try_emplace(selectorOp, comdat); 1273 } 1274 } 1275 return success(); 1276 } 1277 1278 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, 1279 llvm::Instruction *inst) { 1280 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) 1281 inst->setMetadata(llvm::LLVMContext::MD_access_group, node); 1282 } 1283 1284 llvm::MDNode * 1285 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) { 1286 auto [scopeIt, scopeInserted] = 1287 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr); 1288 if (!scopeInserted) 1289 return scopeIt->second; 1290 llvm::LLVMContext &ctx = llvmModule->getContext(); 1291 // Convert the domain metadata node if necessary. 1292 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace( 1293 aliasScopeAttr.getDomain(), nullptr); 1294 if (insertedDomain) { 1295 llvm::SmallVector<llvm::Metadata *, 2> operands; 1296 // Placeholder for self-reference. 1297 operands.push_back({}); 1298 if (StringAttr description = aliasScopeAttr.getDomain().getDescription()) 1299 operands.push_back(llvm::MDString::get(ctx, description)); 1300 domainIt->second = llvm::MDNode::get(ctx, operands); 1301 // Self-reference for uniqueness. 1302 domainIt->second->replaceOperandWith(0, domainIt->second); 1303 } 1304 // Convert the scope metadata node. 1305 assert(domainIt->second && "Scope's domain should already be valid"); 1306 llvm::SmallVector<llvm::Metadata *, 3> operands; 1307 // Placeholder for self-reference. 1308 operands.push_back({}); 1309 operands.push_back(domainIt->second); 1310 if (StringAttr description = aliasScopeAttr.getDescription()) 1311 operands.push_back(llvm::MDString::get(ctx, description)); 1312 scopeIt->second = llvm::MDNode::get(ctx, operands); 1313 // Self-reference for uniqueness. 1314 scopeIt->second->replaceOperandWith(0, scopeIt->second); 1315 return scopeIt->second; 1316 } 1317 1318 llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes( 1319 ArrayRef<AliasScopeAttr> aliasScopeAttrs) { 1320 SmallVector<llvm::Metadata *> nodes; 1321 nodes.reserve(aliasScopeAttrs.size()); 1322 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs) 1323 nodes.push_back(getOrCreateAliasScope(aliasScopeAttr)); 1324 return llvm::MDNode::get(getLLVMContext(), nodes); 1325 } 1326 1327 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, 1328 llvm::Instruction *inst) { 1329 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) { 1330 if (!aliasScopeAttrs || aliasScopeAttrs.empty()) 1331 return; 1332 llvm::MDNode *node = getOrCreateAliasScopes( 1333 llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>())); 1334 inst->setMetadata(kind, node); 1335 }; 1336 1337 populateScopeMetadata(op.getAliasScopesOrNull(), 1338 llvm::LLVMContext::MD_alias_scope); 1339 populateScopeMetadata(op.getNoAliasScopesOrNull(), 1340 llvm::LLVMContext::MD_noalias); 1341 } 1342 1343 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { 1344 return tbaaMetadataMapping.lookup(tbaaAttr); 1345 } 1346 1347 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, 1348 llvm::Instruction *inst) { 1349 ArrayAttr tagRefs = op.getTBAATagsOrNull(); 1350 if (!tagRefs || tagRefs.empty()) 1351 return; 1352 1353 // LLVM IR currently does not support attaching more than one TBAA access tag 1354 // to a memory accessing instruction. It may be useful to support this in 1355 // future, but for the time being just ignore the metadata if MLIR operation 1356 // has multiple access tags. 1357 if (tagRefs.size() > 1) { 1358 op.emitWarning() << "TBAA access tags were not translated, because LLVM " 1359 "IR only supports a single tag per instruction"; 1360 return; 1361 } 1362 1363 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0])); 1364 inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); 1365 } 1366 1367 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { 1368 DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); 1369 if (!weightsAttr) 1370 return; 1371 1372 llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op); 1373 assert(inst && "expected the operation to have a mapping to an instruction"); 1374 SmallVector<uint32_t> weights(weightsAttr.asArrayRef()); 1375 inst->setMetadata( 1376 llvm::LLVMContext::MD_prof, 1377 llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); 1378 } 1379 1380 LogicalResult ModuleTranslation::createTBAAMetadata() { 1381 llvm::LLVMContext &ctx = llvmModule->getContext(); 1382 llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); 1383 1384 // Walk the entire module and create all metadata nodes for the TBAA 1385 // attributes. The code below relies on two invariants of the 1386 // `AttrTypeWalker`: 1387 // 1. Attributes are visited in post-order: Since the attributes create a DAG, 1388 // this ensures that any lookups into `tbaaMetadataMapping` for child 1389 // attributes succeed. 1390 // 2. Attributes are only ever visited once: This way we don't leak any 1391 // LLVM metadata instances. 1392 AttrTypeWalker walker; 1393 walker.addWalk([&](TBAARootAttr root) { 1394 tbaaMetadataMapping.insert( 1395 {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))}); 1396 }); 1397 1398 walker.addWalk([&](TBAATypeDescriptorAttr descriptor) { 1399 SmallVector<llvm::Metadata *> operands; 1400 operands.push_back(llvm::MDString::get(ctx, descriptor.getId())); 1401 for (TBAAMemberAttr member : descriptor.getMembers()) { 1402 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc())); 1403 operands.push_back(llvm::ConstantAsMetadata::get( 1404 llvm::ConstantInt::get(offsetTy, member.getOffset()))); 1405 } 1406 1407 tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)}); 1408 }); 1409 1410 walker.addWalk([&](TBAATagAttr tag) { 1411 SmallVector<llvm::Metadata *> operands; 1412 1413 operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType())); 1414 operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType())); 1415 1416 operands.push_back(llvm::ConstantAsMetadata::get( 1417 llvm::ConstantInt::get(offsetTy, tag.getOffset()))); 1418 if (tag.getConstant()) 1419 operands.push_back( 1420 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy, 1))); 1421 1422 tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)}); 1423 }); 1424 1425 mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) { 1426 if (auto attr = analysisOpInterface.getTBAATagsOrNull()) 1427 walker.walk(attr); 1428 }); 1429 1430 return success(); 1431 } 1432 1433 void ModuleTranslation::setLoopMetadata(Operation *op, 1434 llvm::Instruction *inst) { 1435 LoopAnnotationAttr attr = 1436 TypeSwitch<Operation *, LoopAnnotationAttr>(op) 1437 .Case<LLVM::BrOp, LLVM::CondBrOp>( 1438 [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }); 1439 if (!attr) 1440 return; 1441 llvm::MDNode *loopMD = 1442 loopAnnotationTranslation->translateLoopAnnotation(attr, op); 1443 inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); 1444 } 1445 1446 llvm::Type *ModuleTranslation::convertType(Type type) { 1447 return typeTranslator.translateType(type); 1448 } 1449 1450 /// A helper to look up remapped operands in the value remapping table. 1451 SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) { 1452 SmallVector<llvm::Value *> remapped; 1453 remapped.reserve(values.size()); 1454 for (Value v : values) 1455 remapped.push_back(lookupValue(v)); 1456 return remapped; 1457 } 1458 1459 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { 1460 if (!ompBuilder) { 1461 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); 1462 ompBuilder->initialize(); 1463 1464 // Flags represented as top-level OpenMP dialect attributes are set in 1465 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set 1466 // the default configuration. 1467 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig( 1468 /* IsTargetDevice = */ false, /* IsGPU = */ false, 1469 /* OpenMPOffloadMandatory = */ false, 1470 /* HasRequiresReverseOffload = */ false, 1471 /* HasRequiresUnifiedAddress = */ false, 1472 /* HasRequiresUnifiedSharedMemory = */ false, 1473 /* HasRequiresDynamicAllocators = */ false)); 1474 } 1475 return ompBuilder.get(); 1476 } 1477 1478 llvm::DILocation *ModuleTranslation::translateLoc(Location loc, 1479 llvm::DILocalScope *scope) { 1480 return debugTranslation->translateLoc(loc, scope); 1481 } 1482 1483 llvm::DIExpression * 1484 ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) { 1485 return debugTranslation->translateExpression(attr); 1486 } 1487 1488 llvm::DIGlobalVariableExpression * 1489 ModuleTranslation::translateGlobalVariableExpression( 1490 LLVM::DIGlobalVariableExpressionAttr attr) { 1491 return debugTranslation->translateGlobalVariableExpression(attr); 1492 } 1493 1494 llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) { 1495 return debugTranslation->translate(attr); 1496 } 1497 1498 llvm::NamedMDNode * 1499 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) { 1500 return llvmModule->getOrInsertNamedMetadata(name); 1501 } 1502 1503 void ModuleTranslation::StackFrame::anchor() {} 1504 1505 static std::unique_ptr<llvm::Module> 1506 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, 1507 StringRef name) { 1508 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); 1509 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext); 1510 if (auto dataLayoutAttr = 1511 m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { 1512 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue()); 1513 } else { 1514 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout("")); 1515 if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) { 1516 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) { 1517 llvmDataLayout = 1518 translateDataLayout(spec, DataLayout(iface), m->getLoc()); 1519 } 1520 } else if (auto mod = dyn_cast<ModuleOp>(m)) { 1521 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) { 1522 llvmDataLayout = 1523 translateDataLayout(spec, DataLayout(mod), m->getLoc()); 1524 } 1525 } 1526 if (failed(llvmDataLayout)) 1527 return nullptr; 1528 llvmModule->setDataLayout(*llvmDataLayout); 1529 } 1530 if (auto targetTripleAttr = 1531 m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) 1532 llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue()); 1533 1534 return llvmModule; 1535 } 1536 1537 std::unique_ptr<llvm::Module> 1538 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, 1539 StringRef name) { 1540 if (!satisfiesLLVMModule(module)) { 1541 module->emitOpError("can not be translated to an LLVMIR module"); 1542 return nullptr; 1543 } 1544 1545 std::unique_ptr<llvm::Module> llvmModule = 1546 prepareLLVMModule(module, llvmContext, name); 1547 if (!llvmModule) 1548 return nullptr; 1549 1550 LLVM::ensureDistinctSuccessors(module); 1551 1552 ModuleTranslation translator(module, std::move(llvmModule)); 1553 llvm::IRBuilder<> llvmBuilder(llvmContext); 1554 1555 // Convert module before functions and operations inside, so dialect 1556 // attributes can be used to change dialect-specific global configurations via 1557 // `amendOperation()`. These configurations can then influence the translation 1558 // of operations afterwards. 1559 if (failed(translator.convertOperation(*module, llvmBuilder))) 1560 return nullptr; 1561 1562 if (failed(translator.convertComdats())) 1563 return nullptr; 1564 if (failed(translator.convertFunctionSignatures())) 1565 return nullptr; 1566 if (failed(translator.convertGlobals())) 1567 return nullptr; 1568 if (failed(translator.createTBAAMetadata())) 1569 return nullptr; 1570 1571 // Convert other top-level operations if possible. 1572 for (Operation &o : getModuleBody(module).getOperations()) { 1573 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp, 1574 LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) && 1575 !o.hasTrait<OpTrait::IsTerminator>() && 1576 failed(translator.convertOperation(o, llvmBuilder))) { 1577 return nullptr; 1578 } 1579 } 1580 1581 // Operations in function bodies with symbolic references must be converted 1582 // after the top-level operations they refer to are declared, so we do it 1583 // last. 1584 if (failed(translator.convertFunctions())) 1585 return nullptr; 1586 1587 if (llvm::verifyModule(*translator.llvmModule, &llvm::errs())) 1588 return nullptr; 1589 1590 return std::move(translator.llvmModule); 1591 } 1592