1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the translation between an MLIR LLVM dialect module and 10 // the corresponding LLVMIR module. It only handles core LLVM IR operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 15 16 #include "AttrKindDetail.h" 17 #include "DebugTranslation.h" 18 #include "LoopAnnotationTranslation.h" 19 #include "mlir/Dialect/DLTI/DLTI.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" 22 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" 23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 24 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" 25 #include "mlir/IR/AttrTypeSubElements.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/BuiltinTypes.h" 29 #include "mlir/IR/RegionGraphTraits.h" 30 #include "mlir/Support/LLVM.h" 31 #include "mlir/Support/LogicalResult.h" 32 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" 33 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 34 35 #include "llvm/ADT/PostOrderIterator.h" 36 #include "llvm/ADT/SetVector.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 39 #include "llvm/IR/BasicBlock.h" 40 #include "llvm/IR/CFG.h" 41 #include "llvm/IR/Constants.h" 42 #include "llvm/IR/DerivedTypes.h" 43 #include "llvm/IR/IRBuilder.h" 44 #include "llvm/IR/InlineAsm.h" 45 #include "llvm/IR/IntrinsicsNVPTX.h" 46 #include "llvm/IR/LLVMContext.h" 47 #include "llvm/IR/MDBuilder.h" 48 #include "llvm/IR/Module.h" 49 
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;

#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"

/// Builds the LLVM IR data layout described by the data layout spec
/// `attribute`. String-keyed entries are translated for the endianness, alloca
/// memory space and stack alignment keys. Type-keyed entries are translated for
/// signless integers, floats and LLVM pointers, with sizes and alignments
/// obtained by querying `dataLayout`; entries keyed by the index type are
/// skipped. Any other entry is reported as an error at `loc` (an unknown
/// location is used when `loc` is not provided) and makes the translation fail.
static FailureOr<llvm::DataLayout>
translateDataLayout(DataLayoutSpecInterface attribute,
                    const DataLayout &dataLayout,
                    std::optional<Location> loc = std::nullopt) {
  if (!loc)
    loc = UnknownLoc::get(attribute.getContext());

  std::string spec;
  llvm::raw_string_ostream os(spec);

  // First pass: entries identified by a string key.
  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
    auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
    if (!key)
      continue;

    StringRef keyName = key.getValue();
    if (keyName == DLTIDialect::kDataLayoutEndiannessKey) {
      auto endianness = cast<StringAttr>(entry.getValue());
      if (endianness.getValue() == DLTIDialect::kDataLayoutEndiannessLittle)
        os << "-" << "e";
      else
        os << "-" << "E";
      os.flush();
    } else if (keyName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
      uint64_t memorySpace =
          cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
      // Address space 0 is the default and is not spelled out.
      if (memorySpace != 0) {
        os << "-A" << memorySpace;
        os.flush();
      }
    } else if (keyName == DLTIDialect::kDataLayoutStackAlignmentKey) {
      uint64_t stackAlignment =
          cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
      // A zero alignment is the default and is not spelled out.
      if (stackAlignment != 0) {
        os << "-S" << stackAlignment;
        os.flush();
      }
    } else {
      emitError(*loc) << "unsupported data layout key " << key;
      return failure();
    }
  }

  // Second pass: entries identified by a type. The entries only tell which
  // types were explicitly specified; the actual numbers come from data layout
  // queries wherever possible.
  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
    auto keyType = llvm::dyn_cast_if_present<Type>(entry.getKey());
    // Skip non-type keys; the layout of the index type is irrelevant here.
    if (!keyType || isa<IndexType>(keyType))
      continue;

    os << "-";
    LogicalResult status =
        llvm::TypeSwitch<Type, LogicalResult>(keyType)
            .Case<IntegerType, Float16Type, Float32Type, Float64Type,
                  Float80Type, Float128Type>(
                [&](Type scalarType) -> LogicalResult {
                  auto intType = dyn_cast<IntegerType>(scalarType);
                  if (intType &&
                      intType.getSignedness() != IntegerType::Signless)
                    return emitError(*loc)
                           << "unsupported data layout for non-signless integer "
                           << intType;
                  os << (intType ? "i" : "f");

                  unsigned sizeInBits =
                      dataLayout.getTypeSizeInBits(scalarType);
                  unsigned abiInBits =
                      dataLayout.getTypeABIAlignment(scalarType) * 8u;
                  unsigned preferredInBits =
                      dataLayout.getTypePreferredAlignment(scalarType) * 8u;
                  os << sizeInBits << ":" << abiInBits;
                  // The preferred alignment is only printed when it differs
                  // from the ABI alignment.
                  if (abiInBits != preferredInBits)
                    os << ":" << preferredInBits;
                  return success();
                })
            .Case([&](LLVMPointerType ptrType) {
              os << "p" << ptrType.getAddressSpace() << ":";
              unsigned sizeInBits = dataLayout.getTypeSizeInBits(keyType);
              unsigned abiInBits =
                  dataLayout.getTypeABIAlignment(keyType) * 8u;
              unsigned preferredInBits =
                  dataLayout.getTypePreferredAlignment(keyType) * 8u;
              os << sizeInBits << ":" << abiInBits << ":" << preferredInBits;
              // The index bitwidth is appended only when the entry has one.
              if (std::optional<unsigned> indexWidth = extractPointerSpecValue(
                      entry.getValue(), PtrDLEntryPos::Index))
                os << ":" << *indexWidth;
              return success();
            })
            .Default([&](Type otherType) {
              return emitError(*loc)
                     << "unsupported type in data layout: " << otherType;
            });
    if (failed(status))
      return failure();
  }
  os.flush();

  // Every component was emitted with a leading separator; drop the first one.
  StringRef layoutSpec(spec);
  if (layoutSpec.startswith("-"))
    layoutSpec = layoutSpec.drop_front();

  return llvm::DataLayout(layoutSpec);
}

/// Assembles a constant of the sequential (array or vector) LLVM type `type`
/// out of the scalar constants in `constants`, recursing into nested
/// sequential types. `shape` lists the number of elements at each nesting
/// level. Constants are consumed from the front of `constants`, which is
/// updated in place to reference the not-yet-consumed tail. Reports errors at
/// `loc` and returns nullptr on error.
static llvm::Constant *
buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
                        ArrayRef<int64_t> shape, llvm::Type *type,
                        Location loc) {
  // Base case: no dimensions left, take the next scalar constant.
  if (shape.empty()) {
    llvm::Constant *scalar = constants.front();
    constants = constants.drop_front();
    return scalar;
  }

  llvm::Type *elementType = nullptr;
  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type))
    elementType = arrayTy->getElementType();
  else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type))
    elementType = vectorTy->getElementType();

  if (!elementType) {
    emitError(loc) << "expected sequential LLVM types wrapping a scalar";
    return nullptr;
  }

  const int64_t numElements = shape.front();
  const ArrayRef<int64_t> innerShape = shape.drop_front();

  SmallVector<llvm::Constant *, 8> nested;
  nested.reserve(numElements);
  for (int64_t i = 0; i < numElements; ++i) {
    llvm::Constant *child =
        buildSequentialConstant(constants, innerShape, elementType, loc);
    nested.push_back(child);
    if (!child)
      return nullptr;
  }

  // Only the innermost dimension can be an LLVM vector; everything else is
  // built as an array.
  if (shape.size() == 1 && type->isVectorTy())
    return llvm::ConstantVector::get(nested);
  return llvm::ConstantArray::get(
      llvm::ArrayType::get(elementType, numElements), nested);
}

/// Returns the first non-sequential type nested in sequential types.
209 static llvm::Type *getInnermostElementType(llvm::Type *type) { 210 do { 211 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) { 212 type = arrayTy->getElementType(); 213 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) { 214 type = vectorTy->getElementType(); 215 } else { 216 return type; 217 } 218 } while (true); 219 } 220 221 /// Convert a dense elements attribute to an LLVM IR constant using its raw data 222 /// storage if possible. This supports elements attributes of tensor or vector 223 /// type and avoids constructing separate objects for individual values of the 224 /// innermost dimension. Constants for other dimensions are still constructed 225 /// recursively. Returns null if constructing from raw data is not supported for 226 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports 227 /// other errors at `loc`. 228 static llvm::Constant * 229 convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, 230 llvm::Type *llvmType, 231 const ModuleTranslation &moduleTranslation) { 232 if (!denseElementsAttr) 233 return nullptr; 234 235 llvm::Type *innermostLLVMType = getInnermostElementType(llvmType); 236 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType)) 237 return nullptr; 238 239 ShapedType type = denseElementsAttr.getType(); 240 if (type.getNumElements() == 0) 241 return nullptr; 242 243 // Check that the raw data size matches what is expected for the scalar size. 244 // TODO: in theory, we could repack the data here to keep constructing from 245 // raw data. 246 // TODO: we may also need to consider endianness when cross-compiling to an 247 // architecture where it is different. 248 unsigned elementByteSize = denseElementsAttr.getRawData().size() / 249 denseElementsAttr.getNumElements(); 250 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) 251 return nullptr; 252 253 // Compute the shape of all dimensions but the innermost. 
Note that the 254 // innermost dimension may be that of the vector element type. 255 bool hasVectorElementType = isa<VectorType>(type.getElementType()); 256 unsigned numAggregates = 257 denseElementsAttr.getNumElements() / 258 (hasVectorElementType ? 1 259 : denseElementsAttr.getType().getShape().back()); 260 ArrayRef<int64_t> outerShape = type.getShape(); 261 if (!hasVectorElementType) 262 outerShape = outerShape.drop_back(); 263 264 // Handle the case of vector splat, LLVM has special support for it. 265 if (denseElementsAttr.isSplat() && 266 (isa<VectorType>(type) || hasVectorElementType)) { 267 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( 268 innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc, 269 moduleTranslation); 270 llvm::Constant *splatVector = 271 llvm::ConstantDataVector::getSplat(0, splatValue); 272 SmallVector<llvm::Constant *> constants(numAggregates, splatVector); 273 ArrayRef<llvm::Constant *> constantsRef = constants; 274 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc); 275 } 276 if (denseElementsAttr.isSplat()) 277 return nullptr; 278 279 // In case of non-splat, create a constructor for the innermost constant from 280 // a piece of raw data. 
281 std::function<llvm::Constant *(StringRef)> buildCstData; 282 if (isa<TensorType>(type)) { 283 auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); 284 if (vectorElementType && vectorElementType.getRank() == 1) { 285 buildCstData = [&](StringRef data) { 286 return llvm::ConstantDataVector::getRaw( 287 data, vectorElementType.getShape().back(), innermostLLVMType); 288 }; 289 } else if (!vectorElementType) { 290 buildCstData = [&](StringRef data) { 291 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), 292 innermostLLVMType); 293 }; 294 } 295 } else if (isa<VectorType>(type)) { 296 buildCstData = [&](StringRef data) { 297 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), 298 innermostLLVMType); 299 }; 300 } 301 if (!buildCstData) 302 return nullptr; 303 304 // Create innermost constants and defer to the default constant creation 305 // mechanism for other dimensions. 306 SmallVector<llvm::Constant *> constants; 307 unsigned aggregateSize = denseElementsAttr.getType().getShape().back() * 308 (innermostLLVMType->getScalarSizeInBits() / 8); 309 constants.reserve(numAggregates); 310 for (unsigned i = 0; i < numAggregates; ++i) { 311 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize, 312 aggregateSize); 313 constants.push_back(buildCstData(data)); 314 } 315 316 ArrayRef<llvm::Constant *> constantsRef = constants; 317 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc); 318 } 319 320 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. 321 /// This currently supports integer, floating point, splat and dense element 322 /// attributes and combinations thereof. Also, an array attribute with two 323 /// elements is supported to represent a complex constant. In case of error, 324 /// report it to `loc` and return nullptr. 
325 llvm::Constant *mlir::LLVM::detail::getLLVMConstant( 326 llvm::Type *llvmType, Attribute attr, Location loc, 327 const ModuleTranslation &moduleTranslation) { 328 if (!attr) 329 return llvm::UndefValue::get(llvmType); 330 if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { 331 auto arrayAttr = dyn_cast<ArrayAttr>(attr); 332 if (!arrayAttr || arrayAttr.size() != 2) { 333 emitError(loc, "expected struct type to be a complex number"); 334 return nullptr; 335 } 336 llvm::Type *elementType = structType->getElementType(0); 337 llvm::Constant *real = 338 getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); 339 if (!real) 340 return nullptr; 341 llvm::Constant *imag = 342 getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); 343 if (!imag) 344 return nullptr; 345 return llvm::ConstantStruct::get(structType, {real, imag}); 346 } 347 if (auto *targetExtType = dyn_cast<::llvm::TargetExtType>(llvmType)) { 348 // TODO: Replace with 'zeroinitializer' once there is a dedicated 349 // zeroinitializer operation in the LLVM dialect. 350 auto intAttr = dyn_cast<IntegerAttr>(attr); 351 if (!intAttr || intAttr.getInt() != 0) 352 emitError(loc, 353 "Only zero-initialization allowed for target extension type"); 354 355 return llvm::ConstantTargetNone::get(targetExtType); 356 } 357 // For integer types, we allow a mismatch in sizes as the index type in 358 // MLIR might have a different size than the index type in the LLVM module. 359 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) 360 return llvm::ConstantInt::get( 361 llvmType, 362 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); 363 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { 364 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); 365 // Special case for 8-bit floats, which are represented by integers due to 366 // the lack of native fp8 types in LLVM at the moment. 
Additionally, handle 367 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats 368 // to i16. 369 unsigned floatWidth = APFloat::getSizeInBits(sem); 370 if (llvmType->isIntegerTy(floatWidth)) 371 return llvm::ConstantInt::get(llvmType, 372 floatAttr.getValue().bitcastToAPInt()); 373 if (llvmType != 374 llvm::Type::getFloatingPointTy(llvmType->getContext(), 375 floatAttr.getValue().getSemantics())) { 376 emitError(loc, "FloatAttr does not match expected type of the constant"); 377 return nullptr; 378 } 379 return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); 380 } 381 if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr)) 382 return llvm::ConstantExpr::getBitCast( 383 moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType); 384 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) { 385 llvm::Type *elementType; 386 uint64_t numElements; 387 bool isScalable = false; 388 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) { 389 elementType = arrayTy->getElementType(); 390 numElements = arrayTy->getNumElements(); 391 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) { 392 elementType = fVectorTy->getElementType(); 393 numElements = fVectorTy->getNumElements(); 394 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) { 395 elementType = sVectorTy->getElementType(); 396 numElements = sVectorTy->getMinNumElements(); 397 isScalable = true; 398 } else { 399 llvm_unreachable("unrecognized constant vector type"); 400 } 401 // Splat value is a scalar. Extract it only if the element type is not 402 // another sequence type. The recursion terminates because each step removes 403 // one outer sequential type. 404 bool elementTypeSequential = 405 isa<llvm::ArrayType, llvm::VectorType>(elementType); 406 llvm::Constant *child = getLLVMConstant( 407 elementType, 408 elementTypeSequential ? 
splatAttr 409 : splatAttr.getSplatValue<Attribute>(), 410 loc, moduleTranslation); 411 if (!child) 412 return nullptr; 413 if (llvmType->isVectorTy()) 414 return llvm::ConstantVector::getSplat( 415 llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child); 416 if (llvmType->isArrayTy()) { 417 auto *arrayType = llvm::ArrayType::get(elementType, numElements); 418 SmallVector<llvm::Constant *, 8> constants(numElements, child); 419 return llvm::ConstantArray::get(arrayType, constants); 420 } 421 } 422 423 // Try using raw elements data if possible. 424 if (llvm::Constant *result = 425 convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr), 426 llvmType, moduleTranslation)) { 427 return result; 428 } 429 430 // Fall back to element-by-element construction otherwise. 431 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) { 432 assert(elementsAttr.getShapedType().hasStaticShape()); 433 assert(!elementsAttr.getShapedType().getShape().empty() && 434 "unexpected empty elements attribute shape"); 435 436 SmallVector<llvm::Constant *, 8> constants; 437 constants.reserve(elementsAttr.getNumElements()); 438 llvm::Type *innermostType = getInnermostElementType(llvmType); 439 for (auto n : elementsAttr.getValues<Attribute>()) { 440 constants.push_back( 441 getLLVMConstant(innermostType, n, loc, moduleTranslation)); 442 if (!constants.back()) 443 return nullptr; 444 } 445 ArrayRef<llvm::Constant *> constantsRef = constants; 446 llvm::Constant *result = buildSequentialConstant( 447 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc); 448 assert(constantsRef.empty() && "did not consume all elemental constants"); 449 return result; 450 } 451 452 if (auto stringAttr = dyn_cast<StringAttr>(attr)) { 453 return llvm::ConstantDataArray::get( 454 moduleTranslation.getLLVMContext(), 455 ArrayRef<char>{stringAttr.getValue().data(), 456 stringAttr.getValue().size()}); 457 } 458 emitError(loc, "unsupported constant value"); 459 return nullptr; 460 } 461 
462 ModuleTranslation::ModuleTranslation(Operation *module, 463 std::unique_ptr<llvm::Module> llvmModule) 464 : mlirModule(module), llvmModule(std::move(llvmModule)), 465 debugTranslation( 466 std::make_unique<DebugTranslation>(module, *this->llvmModule)), 467 loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>( 468 *this, *this->llvmModule)), 469 typeTranslator(this->llvmModule->getContext()), 470 iface(module->getContext()) { 471 assert(satisfiesLLVMModule(mlirModule) && 472 "mlirModule should honor LLVM's module semantics."); 473 } 474 475 ModuleTranslation::~ModuleTranslation() { 476 if (ompBuilder) 477 ompBuilder->finalize(); 478 } 479 480 void ModuleTranslation::forgetMapping(Region ®ion) { 481 SmallVector<Region *> toProcess; 482 toProcess.push_back(®ion); 483 while (!toProcess.empty()) { 484 Region *current = toProcess.pop_back_val(); 485 for (Block &block : *current) { 486 blockMapping.erase(&block); 487 for (Value arg : block.getArguments()) 488 valueMapping.erase(arg); 489 for (Operation &op : block) { 490 for (Value value : op.getResults()) 491 valueMapping.erase(value); 492 if (op.hasSuccessors()) 493 branchMapping.erase(&op); 494 if (isa<LLVM::GlobalOp>(op)) 495 globalsMapping.erase(&op); 496 llvm::append_range( 497 toProcess, 498 llvm::map_range(op.getRegions(), [](Region &r) { return &r; })); 499 } 500 } 501 } 502 } 503 504 /// Get the SSA value passed to the current block from the terminator operation 505 /// of its predecessor. 
506 static Value getPHISourceValue(Block *current, Block *pred, 507 unsigned numArguments, unsigned index) { 508 Operation &terminator = *pred->getTerminator(); 509 if (isa<LLVM::BrOp>(terminator)) 510 return terminator.getOperand(index); 511 512 #ifndef NDEBUG 513 llvm::SmallPtrSet<Block *, 4> seenSuccessors; 514 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) { 515 Block *successor = terminator.getSuccessor(i); 516 auto branch = cast<BranchOpInterface>(terminator); 517 SuccessorOperands successorOperands = branch.getSuccessorOperands(i); 518 assert( 519 (!seenSuccessors.contains(successor) || successorOperands.empty()) && 520 "successors with arguments in LLVM branches must be different blocks"); 521 seenSuccessors.insert(successor); 522 } 523 #endif 524 525 // For instructions that branch based on a condition value, we need to take 526 // the operands for the branch that was taken. 527 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) { 528 // For conditional branches, we take the operands from either the "true" or 529 // the "false" branch. 530 return condBranchOp.getSuccessor(0) == current 531 ? condBranchOp.getTrueDestOperands()[index] 532 : condBranchOp.getFalseDestOperands()[index]; 533 } 534 535 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) { 536 // For switches, we take the operands from either the default case, or from 537 // the case branch that was taken. 538 if (switchOp.getDefaultDestination() == current) 539 return switchOp.getDefaultOperands()[index]; 540 for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations())) 541 if (i.value() == current) 542 return switchOp.getCaseOperands(i.index())[index]; 543 } 544 545 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) { 546 return invokeOp.getNormalDest() == current 547 ? 
invokeOp.getNormalDestOperands()[index] 548 : invokeOp.getUnwindDestOperands()[index]; 549 } 550 551 llvm_unreachable( 552 "only branch, switch or invoke operations can be terminators " 553 "of a block that has successors"); 554 } 555 556 /// Connect the PHI nodes to the results of preceding blocks. 557 void mlir::LLVM::detail::connectPHINodes(Region ®ion, 558 const ModuleTranslation &state) { 559 // Skip the first block, it cannot be branched to and its arguments correspond 560 // to the arguments of the LLVM function. 561 for (Block &bb : llvm::drop_begin(region)) { 562 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb); 563 auto phis = llvmBB->phis(); 564 auto numArguments = bb.getNumArguments(); 565 assert(numArguments == std::distance(phis.begin(), phis.end())); 566 for (auto [index, phiNode] : llvm::enumerate(phis)) { 567 for (auto *pred : bb.getPredecessors()) { 568 // Find the LLVM IR block that contains the converted terminator 569 // instruction and use it in the PHI node. Note that this block is not 570 // necessarily the same as state.lookupBlock(pred), some operations 571 // (in particular, OpenMP operations using OpenMPIRBuilder) may have 572 // split the blocks. 573 llvm::Instruction *terminator = 574 state.lookupBranch(pred->getTerminator()); 575 assert(terminator && "missing the mapping for a terminator"); 576 phiNode.addIncoming(state.lookupValue(getPHISourceValue( 577 &bb, pred, numArguments, index)), 578 terminator->getParent()); 579 } 580 } 581 } 582 } 583 584 /// Sort function blocks topologically. 585 SetVector<Block *> 586 mlir::LLVM::detail::getTopologicallySortedBlocks(Region ®ion) { 587 // For each block that has not been visited yet (i.e. that has no 588 // predecessors), add it to the list as well as its successors. 
589 SetVector<Block *> blocks; 590 for (Block &b : region) { 591 if (blocks.count(&b) == 0) { 592 llvm::ReversePostOrderTraversal<Block *> traversal(&b); 593 blocks.insert(traversal.begin(), traversal.end()); 594 } 595 } 596 assert(blocks.size() == region.getBlocks().size() && 597 "some blocks are not sorted"); 598 599 return blocks; 600 } 601 602 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( 603 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, 604 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { 605 llvm::Module *module = builder.GetInsertBlock()->getModule(); 606 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys); 607 return builder.CreateCall(fn, args); 608 } 609 610 /// Given a single MLIR operation, create the corresponding LLVM IR operation 611 /// using the `builder`. 612 LogicalResult 613 ModuleTranslation::convertOperation(Operation &op, 614 llvm::IRBuilderBase &builder) { 615 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op); 616 if (!opIface) 617 return op.emitError("cannot be converted to LLVM IR: missing " 618 "`LLVMTranslationDialectInterface` registration for " 619 "dialect for op: ") 620 << op.getName(); 621 622 if (failed(opIface->convertOperation(&op, builder, *this))) 623 return op.emitError("LLVM Translation failed for operation: ") 624 << op.getName(); 625 626 return convertDialectAttributes(&op); 627 } 628 629 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes 630 /// to define values corresponding to the MLIR block arguments. These nodes 631 /// are not connected to the source basic blocks, which may not exist yet. Uses 632 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have 633 /// been created for `bb` and included in the block mapping. Inserts new 634 /// instructions at the end of the block and leaves `builder` in a state 635 /// suitable for further insertion into the end of the block. 
636 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, 637 llvm::IRBuilderBase &builder) { 638 builder.SetInsertPoint(lookupBlock(&bb)); 639 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); 640 641 // Before traversing operations, make block arguments available through 642 // value remapping and PHI nodes, but do not add incoming edges for the PHI 643 // nodes just yet: those values may be defined by this or following blocks. 644 // This step is omitted if "ignoreArguments" is set. The arguments of the 645 // first block have been already made available through the remapping of 646 // LLVM function arguments. 647 if (!ignoreArguments) { 648 auto predecessors = bb.getPredecessors(); 649 unsigned numPredecessors = 650 std::distance(predecessors.begin(), predecessors.end()); 651 for (auto arg : bb.getArguments()) { 652 auto wrappedType = arg.getType(); 653 if (!isCompatibleType(wrappedType)) 654 return emitError(bb.front().getLoc(), 655 "block argument does not have an LLVM type"); 656 llvm::Type *type = convertType(wrappedType); 657 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); 658 mapValue(arg, phi); 659 } 660 } 661 662 // Traverse operations. 663 for (auto &op : bb) { 664 // Set the current debug location within the builder. 665 builder.SetCurrentDebugLocation( 666 debugTranslation->translateLoc(op.getLoc(), subprogram)); 667 668 if (failed(convertOperation(op, builder))) 669 return failure(); 670 671 // Set the branch weight metadata on the translated instruction. 672 if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) 673 setBranchWeightsMetadata(iface); 674 } 675 676 return success(); 677 } 678 679 /// A helper method to get the single Block in an operation honoring LLVM's 680 /// module requirements. 
681 static Block &getModuleBody(Operation *module) { 682 return module->getRegion(0).front(); 683 } 684 685 /// A helper method to decide if a constant must not be set as a global variable 686 /// initializer. For an external linkage variable, the variable with an 687 /// initializer is considered externally visible and defined in this module, the 688 /// variable without an initializer is externally available and is defined 689 /// elsewhere. 690 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, 691 llvm::Constant *cst) { 692 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) || 693 linkage == llvm::GlobalVariable::ExternalWeakLinkage; 694 } 695 696 /// Sets the runtime preemption specifier of `gv` to dso_local if 697 /// `dsoLocalRequested` is true, otherwise it is left unchanged. 698 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, 699 llvm::GlobalValue *gv) { 700 if (dsoLocalRequested) 701 gv->setDSOLocal(true); 702 } 703 704 /// Create named global variables that correspond to llvm.mlir.global 705 /// definitions. Convert llvm.global_ctors and global_dtors ops. 706 LogicalResult ModuleTranslation::convertGlobals() { 707 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 708 llvm::Type *type = convertType(op.getType()); 709 llvm::Constant *cst = nullptr; 710 if (op.getValueOrNull()) { 711 // String attributes are treated separately because they cannot appear as 712 // in-function constants and are thus not supported by getLLVMConstant. 
713 if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) { 714 cst = llvm::ConstantDataArray::getString( 715 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); 716 type = cst->getType(); 717 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(), 718 *this))) { 719 return failure(); 720 } 721 } 722 723 auto linkage = convertLinkageToLLVM(op.getLinkage()); 724 auto addrSpace = op.getAddrSpace(); 725 726 // LLVM IR requires constant with linkage other than external or weak 727 // external to have initializers. If MLIR does not provide an initializer, 728 // default to undef. 729 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst); 730 if (!dropInitializer && !cst) 731 cst = llvm::UndefValue::get(type); 732 else if (dropInitializer && cst) 733 cst = nullptr; 734 735 auto *var = new llvm::GlobalVariable( 736 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), 737 /*InsertBefore=*/nullptr, 738 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel 739 : llvm::GlobalValue::NotThreadLocal, 740 addrSpace); 741 742 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) { 743 auto selectorOp = cast<ComdatSelectorOp>( 744 SymbolTable::lookupNearestSymbolFrom(op, *comdat)); 745 var->setComdat(comdatMapping.lookup(selectorOp)); 746 } 747 748 if (op.getUnnamedAddr().has_value()) 749 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); 750 751 if (op.getSection().has_value()) 752 var->setSection(*op.getSection()); 753 754 addRuntimePreemptionSpecifier(op.getDsoLocal(), var); 755 756 std::optional<uint64_t> alignment = op.getAlignment(); 757 if (alignment.has_value()) 758 var->setAlignment(llvm::MaybeAlign(alignment.value())); 759 760 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); 761 762 globalsMapping.try_emplace(op, var); 763 } 764 765 // Convert global variable bodies. 
This is done after all global variables 766 // have been created in LLVM IR because a global body may refer to another 767 // global or itself. So all global variables need to be mapped first. 768 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 769 if (Block *initializer = op.getInitializerBlock()) { 770 llvm::IRBuilder<> builder(llvmModule->getContext()); 771 for (auto &op : initializer->without_terminator()) { 772 if (failed(convertOperation(op, builder)) || 773 !isa<llvm::Constant>(lookupValue(op.getResult(0)))) 774 return emitError(op.getLoc(), "unemittable constant value"); 775 } 776 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); 777 llvm::Constant *cst = 778 cast<llvm::Constant>(lookupValue(ret.getOperand(0))); 779 auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op)); 780 if (!shouldDropGlobalInitializer(global->getLinkage(), cst)) 781 global->setInitializer(cst); 782 } 783 } 784 785 // Convert llvm.mlir.global_ctors and dtors. 786 for (Operation &op : getModuleBody(mlirModule)) { 787 auto ctorOp = dyn_cast<GlobalCtorsOp>(op); 788 auto dtorOp = dyn_cast<GlobalDtorsOp>(op); 789 if (!ctorOp && !dtorOp) 790 continue; 791 auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities()) 792 : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities()); 793 auto appendGlobalFn = 794 ctorOp ? 
llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; 795 for (auto symbolAndPriority : range) { 796 llvm::Function *f = lookupFunction( 797 cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue()); 798 appendGlobalFn(*llvmModule, f, 799 cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(), 800 /*Data=*/nullptr); 801 } 802 } 803 804 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) 805 if (failed(convertDialectAttributes(op))) 806 return failure(); 807 808 return success(); 809 } 810 811 /// Attempts to add an attribute identified by `key`, optionally with the given 812 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the 813 /// attribute has a kind known to LLVM IR, create the attribute of this kind, 814 /// otherwise keep it as a string attribute. Performs additional checks for 815 /// attributes known to have or not have a value in order to avoid assertions 816 /// inside LLVM upon construction. 817 static LogicalResult checkedAddLLVMFnAttribute(Location loc, 818 llvm::Function *llvmFunc, 819 StringRef key, 820 StringRef value = StringRef()) { 821 auto kind = llvm::Attribute::getAttrKindFromName(key); 822 if (kind == llvm::Attribute::None) { 823 llvmFunc->addFnAttr(key, value); 824 return success(); 825 } 826 827 if (llvm::Attribute::isIntAttrKind(kind)) { 828 if (value.empty()) 829 return emitError(loc) << "LLVM attribute '" << key << "' expects a value"; 830 831 int64_t result; 832 if (!value.getAsInteger(/*Radix=*/0, result)) 833 llvmFunc->addFnAttr( 834 llvm::Attribute::get(llvmFunc->getContext(), kind, result)); 835 else 836 llvmFunc->addFnAttr(key, value); 837 return success(); 838 } 839 840 if (!value.empty()) 841 return emitError(loc) << "LLVM attribute '" << key 842 << "' does not expect a value, found '" << value 843 << "'"; 844 845 llvmFunc->addFnAttr(kind); 846 return success(); 847 } 848 849 /// Attaches the attributes listed in the given array attribute to `llvmFunc`. 
850 /// Reports error to `loc` if any and returns immediately. Expects `attributes` 851 /// to be an array attribute containing either string attributes, treated as 852 /// value-less LLVM attributes, or array attributes containing two string 853 /// attributes, with the first string being the name of the corresponding LLVM 854 /// attribute and the second string beings its value. Note that even integer 855 /// attributes are expected to have their values expressed as strings. 856 static LogicalResult 857 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes, 858 llvm::Function *llvmFunc) { 859 if (!attributes) 860 return success(); 861 862 for (Attribute attr : *attributes) { 863 if (auto stringAttr = dyn_cast<StringAttr>(attr)) { 864 if (failed( 865 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) 866 return failure(); 867 continue; 868 } 869 870 auto arrayAttr = dyn_cast<ArrayAttr>(attr); 871 if (!arrayAttr || arrayAttr.size() != 2) 872 return emitError(loc) 873 << "expected 'passthrough' to contain string or array attributes"; 874 875 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); 876 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); 877 if (!keyAttr || !valueAttr) 878 return emitError(loc) 879 << "expected arrays within 'passthrough' to contain two strings"; 880 881 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), 882 valueAttr.getValue()))) 883 return failure(); 884 } 885 return success(); 886 } 887 888 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { 889 // Clear the block, branch value mappings, they are only relevant within one 890 // function. 891 blockMapping.clear(); 892 valueMapping.clear(); 893 branchMapping.clear(); 894 llvm::Function *llvmFunc = lookupFunction(func.getName()); 895 896 // Translate the debug information for this function. 897 debugTranslation->translate(func, *llvmFunc); 898 899 // Add function arguments to the value remapping table. 
900 for (auto [mlirArg, llvmArg] : 901 llvm::zip(func.getArguments(), llvmFunc->args())) 902 mapValue(mlirArg, &llvmArg); 903 904 // Check the personality and set it. 905 if (func.getPersonality()) { 906 llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext()); 907 if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(), 908 func.getLoc(), *this)) 909 llvmFunc->setPersonalityFn(pfunc); 910 } 911 912 if (std::optional<StringRef> section = func.getSection()) 913 llvmFunc->setSection(*section); 914 915 if (func.getArmStreaming()) 916 llvmFunc->addFnAttr("aarch64_pstate_sm_enabled"); 917 else if (func.getArmLocallyStreaming()) 918 llvmFunc->addFnAttr("aarch64_pstate_sm_body"); 919 920 // First, create all blocks so we can jump to them. 921 llvm::LLVMContext &llvmContext = llvmFunc->getContext(); 922 for (auto &bb : func) { 923 auto *llvmBB = llvm::BasicBlock::Create(llvmContext); 924 llvmBB->insertInto(llvmFunc); 925 mapBlock(&bb, llvmBB); 926 } 927 928 // Then, convert blocks one by one in topological order to ensure defs are 929 // converted before uses. 930 auto blocks = detail::getTopologicallySortedBlocks(func.getBody()); 931 for (Block *bb : blocks) { 932 llvm::IRBuilder<> builder(llvmContext); 933 if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) 934 return failure(); 935 } 936 937 // After all blocks have been traversed and values mapped, connect the PHI 938 // nodes to the results of preceding blocks. 939 detail::connectPHINodes(func.getBody(), *this); 940 941 // Finally, convert dialect attributes attached to the function. 
942 return convertDialectAttributes(func); 943 } 944 945 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) { 946 for (NamedAttribute attribute : op->getDialectAttrs()) 947 if (failed(iface.amendOperation(op, attribute, *this))) 948 return failure(); 949 return success(); 950 } 951 952 /// Converts the function attributes from LLVMFuncOp and attaches them to the 953 /// llvm::Function. 954 static void convertFunctionAttributes(LLVMFuncOp func, 955 llvm::Function *llvmFunc) { 956 if (!func.getMemory()) 957 return; 958 959 MemoryEffectsAttr memEffects = func.getMemoryAttr(); 960 961 // Add memory effects incrementally. 962 llvm::MemoryEffects newMemEffects = 963 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, 964 convertModRefInfoToLLVM(memEffects.getArgMem())); 965 newMemEffects |= llvm::MemoryEffects( 966 llvm::MemoryEffects::Location::InaccessibleMem, 967 convertModRefInfoToLLVM(memEffects.getInaccessibleMem())); 968 newMemEffects |= 969 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, 970 convertModRefInfoToLLVM(memEffects.getOther())); 971 llvmFunc->setMemoryEffects(newMemEffects); 972 } 973 974 llvm::AttrBuilder 975 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) { 976 llvm::AttrBuilder attrBuilder(llvmModule->getContext()); 977 978 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { 979 Attribute attr = paramAttrs.get(mlirName); 980 // Skip attributes that are not present. 981 if (!attr) 982 continue; 983 984 // NOTE: C++17 does not support capturing structured bindings. 
985 llvm::Attribute::AttrKind llvmKindCap = llvmKind; 986 987 llvm::TypeSwitch<Attribute>(attr) 988 .Case<TypeAttr>([&](auto typeAttr) { 989 attrBuilder.addTypeAttr(llvmKindCap, 990 convertType(typeAttr.getValue())); 991 }) 992 .Case<IntegerAttr>([&](auto intAttr) { 993 attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt()); 994 }) 995 .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); }); 996 } 997 998 return attrBuilder; 999 } 1000 1001 LogicalResult ModuleTranslation::convertFunctionSignatures() { 1002 // Declare all functions first because there may be function calls that form a 1003 // call graph with cycles, or global initializers that reference functions. 1004 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1005 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( 1006 function.getName(), 1007 cast<llvm::FunctionType>(convertType(function.getFunctionType()))); 1008 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee()); 1009 llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage())); 1010 llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv())); 1011 mapFunction(function.getName(), llvmFunc); 1012 addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc); 1013 1014 // Convert function attributes. 1015 convertFunctionAttributes(function, llvmFunc); 1016 1017 // Convert function_entry_count attribute to metadata. 1018 if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount()) 1019 llvmFunc->setEntryCount(entryCount.value()); 1020 1021 // Convert result attributes. 1022 if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { 1023 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]); 1024 llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs)); 1025 } 1026 1027 // Convert argument attributes. 
1028 for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) { 1029 if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) { 1030 llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs); 1031 llvmArg.addAttrs(attrBuilder); 1032 } 1033 } 1034 1035 // Forward the pass-through attributes to LLVM. 1036 if (failed(forwardPassthroughAttributes( 1037 function.getLoc(), function.getPassthrough(), llvmFunc))) 1038 return failure(); 1039 1040 // Convert visibility attribute. 1041 llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_())); 1042 1043 // Convert the comdat attribute. 1044 if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) { 1045 auto selectorOp = cast<ComdatSelectorOp>( 1046 SymbolTable::lookupNearestSymbolFrom(function, *comdat)); 1047 llvmFunc->setComdat(comdatMapping.lookup(selectorOp)); 1048 } 1049 1050 if (auto gc = function.getGarbageCollector()) 1051 llvmFunc->setGC(gc->str()); 1052 1053 if (auto unnamedAddr = function.getUnnamedAddr()) 1054 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr)); 1055 1056 if (auto alignment = function.getAlignment()) 1057 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment)); 1058 } 1059 1060 return success(); 1061 } 1062 1063 LogicalResult ModuleTranslation::convertFunctions() { 1064 // Convert functions. 1065 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1066 // Do not convert external functions, but do process dialect attributes 1067 // attached to them. 
1068 if (function.isExternal()) { 1069 if (failed(convertDialectAttributes(function))) 1070 return failure(); 1071 continue; 1072 } 1073 1074 if (failed(convertOneFunction(function))) 1075 return failure(); 1076 } 1077 1078 return success(); 1079 } 1080 1081 LogicalResult ModuleTranslation::convertComdats() { 1082 for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) { 1083 for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) { 1084 llvm::Module *module = getLLVMModule(); 1085 if (module->getComdatSymbolTable().contains(selectorOp.getSymName())) 1086 return emitError(selectorOp.getLoc()) 1087 << "comdat selection symbols must be unique even in different " 1088 "comdat regions"; 1089 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName()); 1090 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat())); 1091 comdatMapping.try_emplace(selectorOp, comdat); 1092 } 1093 } 1094 return success(); 1095 } 1096 1097 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, 1098 llvm::Instruction *inst) { 1099 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) 1100 inst->setMetadata(llvm::LLVMContext::MD_access_group, node); 1101 } 1102 1103 llvm::MDNode * 1104 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) { 1105 auto [scopeIt, scopeInserted] = 1106 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr); 1107 if (!scopeInserted) 1108 return scopeIt->second; 1109 llvm::LLVMContext &ctx = llvmModule->getContext(); 1110 // Convert the domain metadata node if necessary. 1111 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace( 1112 aliasScopeAttr.getDomain(), nullptr); 1113 if (insertedDomain) { 1114 llvm::SmallVector<llvm::Metadata *, 2> operands; 1115 // Placeholder for self-reference. 
1116 operands.push_back({}); 1117 if (StringAttr description = aliasScopeAttr.getDomain().getDescription()) 1118 operands.push_back(llvm::MDString::get(ctx, description)); 1119 domainIt->second = llvm::MDNode::get(ctx, operands); 1120 // Self-reference for uniqueness. 1121 domainIt->second->replaceOperandWith(0, domainIt->second); 1122 } 1123 // Convert the scope metadata node. 1124 assert(domainIt->second && "Scope's domain should already be valid"); 1125 llvm::SmallVector<llvm::Metadata *, 3> operands; 1126 // Placeholder for self-reference. 1127 operands.push_back({}); 1128 operands.push_back(domainIt->second); 1129 if (StringAttr description = aliasScopeAttr.getDescription()) 1130 operands.push_back(llvm::MDString::get(ctx, description)); 1131 scopeIt->second = llvm::MDNode::get(ctx, operands); 1132 // Self-reference for uniqueness. 1133 scopeIt->second->replaceOperandWith(0, scopeIt->second); 1134 return scopeIt->second; 1135 } 1136 1137 llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes( 1138 ArrayRef<AliasScopeAttr> aliasScopeAttrs) { 1139 SmallVector<llvm::Metadata *> nodes; 1140 nodes.reserve(aliasScopeAttrs.size()); 1141 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs) 1142 nodes.push_back(getOrCreateAliasScope(aliasScopeAttr)); 1143 return llvm::MDNode::get(getLLVMContext(), nodes); 1144 } 1145 1146 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, 1147 llvm::Instruction *inst) { 1148 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) { 1149 if (!aliasScopeAttrs || aliasScopeAttrs.empty()) 1150 return; 1151 llvm::MDNode *node = getOrCreateAliasScopes( 1152 llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>())); 1153 inst->setMetadata(kind, node); 1154 }; 1155 1156 populateScopeMetadata(op.getAliasScopesOrNull(), 1157 llvm::LLVMContext::MD_alias_scope); 1158 populateScopeMetadata(op.getNoAliasScopesOrNull(), 1159 llvm::LLVMContext::MD_noalias); 1160 } 1161 1162 llvm::MDNode 
*ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { 1163 return tbaaMetadataMapping.lookup(tbaaAttr); 1164 } 1165 1166 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, 1167 llvm::Instruction *inst) { 1168 ArrayAttr tagRefs = op.getTBAATagsOrNull(); 1169 if (!tagRefs || tagRefs.empty()) 1170 return; 1171 1172 // LLVM IR currently does not support attaching more than one TBAA access tag 1173 // to a memory accessing instruction. It may be useful to support this in 1174 // future, but for the time being just ignore the metadata if MLIR operation 1175 // has multiple access tags. 1176 if (tagRefs.size() > 1) { 1177 op.emitWarning() << "TBAA access tags were not translated, because LLVM " 1178 "IR only supports a single tag per instruction"; 1179 return; 1180 } 1181 1182 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0])); 1183 inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); 1184 } 1185 1186 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { 1187 DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); 1188 if (!weightsAttr) 1189 return; 1190 1191 llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op); 1192 assert(inst && "expected the operation to have a mapping to an instruction"); 1193 SmallVector<uint32_t> weights(weightsAttr.asArrayRef()); 1194 inst->setMetadata( 1195 llvm::LLVMContext::MD_prof, 1196 llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); 1197 } 1198 1199 LogicalResult ModuleTranslation::createTBAAMetadata() { 1200 llvm::LLVMContext &ctx = llvmModule->getContext(); 1201 llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); 1202 1203 // Walk the entire module and create all metadata nodes for the TBAA 1204 // attributes. The code below relies on two invariants of the 1205 // `AttrTypeWalker`: 1206 // 1. 
Attributes are visited in post-order: Since the attributes create a DAG, 1207 // this ensures that any lookups into `tbaaMetadataMapping` for child 1208 // attributes succeed. 1209 // 2. Attributes are only ever visited once: This way we don't leak any 1210 // LLVM metadata instances. 1211 AttrTypeWalker walker; 1212 walker.addWalk([&](TBAARootAttr root) { 1213 tbaaMetadataMapping.insert( 1214 {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))}); 1215 }); 1216 1217 walker.addWalk([&](TBAATypeDescriptorAttr descriptor) { 1218 SmallVector<llvm::Metadata *> operands; 1219 operands.push_back(llvm::MDString::get(ctx, descriptor.getId())); 1220 for (TBAAMemberAttr member : descriptor.getMembers()) { 1221 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc())); 1222 operands.push_back(llvm::ConstantAsMetadata::get( 1223 llvm::ConstantInt::get(offsetTy, member.getOffset()))); 1224 } 1225 1226 tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)}); 1227 }); 1228 1229 walker.addWalk([&](TBAATagAttr tag) { 1230 SmallVector<llvm::Metadata *> operands; 1231 1232 operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType())); 1233 operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType())); 1234 1235 operands.push_back(llvm::ConstantAsMetadata::get( 1236 llvm::ConstantInt::get(offsetTy, tag.getOffset()))); 1237 if (tag.getConstant()) 1238 operands.push_back( 1239 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy, 1))); 1240 1241 tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)}); 1242 }); 1243 1244 mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) { 1245 if (auto attr = analysisOpInterface.getTBAATagsOrNull()) 1246 walker.walk(attr); 1247 }); 1248 1249 return success(); 1250 } 1251 1252 void ModuleTranslation::setLoopMetadata(Operation *op, 1253 llvm::Instruction *inst) { 1254 LoopAnnotationAttr attr = 1255 TypeSwitch<Operation *, LoopAnnotationAttr>(op) 1256 
.Case<LLVM::BrOp, LLVM::CondBrOp>( 1257 [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }); 1258 if (!attr) 1259 return; 1260 llvm::MDNode *loopMD = 1261 loopAnnotationTranslation->translateLoopAnnotation(attr, op); 1262 inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); 1263 } 1264 1265 llvm::Type *ModuleTranslation::convertType(Type type) { 1266 return typeTranslator.translateType(type); 1267 } 1268 1269 /// A helper to look up remapped operands in the value remapping table. 1270 SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) { 1271 SmallVector<llvm::Value *> remapped; 1272 remapped.reserve(values.size()); 1273 for (Value v : values) 1274 remapped.push_back(lookupValue(v)); 1275 return remapped; 1276 } 1277 1278 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { 1279 if (!ompBuilder) { 1280 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); 1281 1282 bool isTargetDevice = false, isGPU = false; 1283 llvm::StringRef hostIRFilePath = ""; 1284 1285 if (auto deviceAttr = 1286 mlirModule->getAttrOfType<mlir::BoolAttr>("omp.is_target_device")) 1287 isTargetDevice = deviceAttr.getValue(); 1288 1289 if (auto gpuAttr = mlirModule->getAttrOfType<mlir::BoolAttr>("omp.is_gpu")) 1290 isGPU = gpuAttr.getValue(); 1291 1292 if (auto filepathAttr = 1293 mlirModule->getAttrOfType<mlir::StringAttr>("omp.host_ir_filepath")) 1294 hostIRFilePath = filepathAttr.getValue(); 1295 1296 ompBuilder->initialize(hostIRFilePath); 1297 1298 // TODO: set the flags when available 1299 llvm::OpenMPIRBuilderConfig config( 1300 isTargetDevice, isGPU, 1301 /* HasRequiresUnifiedSharedMemory */ false, 1302 /* OpenMPOffloadMandatory */ false); 1303 ompBuilder->setConfig(config); 1304 } 1305 return ompBuilder.get(); 1306 } 1307 1308 llvm::DILocation *ModuleTranslation::translateLoc(Location loc, 1309 llvm::DILocalScope *scope) { 1310 return debugTranslation->translateLoc(loc, scope); 1311 } 1312 1313 llvm::Metadata 
*ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) { 1314 return debugTranslation->translate(attr); 1315 } 1316 1317 llvm::NamedMDNode * 1318 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) { 1319 return llvmModule->getOrInsertNamedMetadata(name); 1320 } 1321 1322 void ModuleTranslation::StackFrame::anchor() {} 1323 1324 static std::unique_ptr<llvm::Module> 1325 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, 1326 StringRef name) { 1327 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); 1328 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext); 1329 if (auto dataLayoutAttr = 1330 m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { 1331 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue()); 1332 } else { 1333 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout("")); 1334 if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) { 1335 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) { 1336 llvmDataLayout = 1337 translateDataLayout(spec, DataLayout(iface), m->getLoc()); 1338 } 1339 } else if (auto mod = dyn_cast<ModuleOp>(m)) { 1340 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) { 1341 llvmDataLayout = 1342 translateDataLayout(spec, DataLayout(mod), m->getLoc()); 1343 } 1344 } 1345 if (failed(llvmDataLayout)) 1346 return nullptr; 1347 llvmModule->setDataLayout(*llvmDataLayout); 1348 } 1349 if (auto targetTripleAttr = 1350 m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) 1351 llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue()); 1352 1353 // Inject declarations for `malloc` and `free` functions that can be used in 1354 // memref allocation/deallocation coming from standard ops lowering. 
1355 llvm::IRBuilder<> builder(llvmContext); 1356 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(), 1357 builder.getInt64Ty()); 1358 llvmModule->getOrInsertFunction("free", builder.getVoidTy(), 1359 builder.getInt8PtrTy()); 1360 1361 return llvmModule; 1362 } 1363 1364 std::unique_ptr<llvm::Module> 1365 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, 1366 StringRef name) { 1367 if (!satisfiesLLVMModule(module)) { 1368 module->emitOpError("can not be translated to an LLVMIR module"); 1369 return nullptr; 1370 } 1371 1372 std::unique_ptr<llvm::Module> llvmModule = 1373 prepareLLVMModule(module, llvmContext, name); 1374 if (!llvmModule) 1375 return nullptr; 1376 1377 LLVM::ensureDistinctSuccessors(module); 1378 1379 ModuleTranslation translator(module, std::move(llvmModule)); 1380 llvm::IRBuilder<> llvmBuilder(llvmContext); 1381 1382 // Convert module before functions and operations inside, so dialect 1383 // attributes can be used to change dialect-specific global configurations via 1384 // `amendOperation()`. These configurations can then influence the translation 1385 // of operations afterwards. 1386 if (failed(translator.convertOperation(*module, llvmBuilder))) 1387 return nullptr; 1388 1389 if (failed(translator.convertComdats())) 1390 return nullptr; 1391 if (failed(translator.convertFunctionSignatures())) 1392 return nullptr; 1393 if (failed(translator.convertGlobals())) 1394 return nullptr; 1395 if (failed(translator.createTBAAMetadata())) 1396 return nullptr; 1397 1398 // Convert other top-level operations if possible. 
1399 for (Operation &o : getModuleBody(module).getOperations()) { 1400 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp, 1401 LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) && 1402 !o.hasTrait<OpTrait::IsTerminator>() && 1403 failed(translator.convertOperation(o, llvmBuilder))) { 1404 return nullptr; 1405 } 1406 } 1407 1408 // Operations in function bodies with symbolic references must be converted 1409 // after the top-level operations they refer to are declared, so we do it 1410 // last. 1411 if (failed(translator.convertFunctions())) 1412 return nullptr; 1413 1414 if (llvm::verifyModule(*translator.llvmModule, &llvm::errs())) 1415 return nullptr; 1416 1417 return std::move(translator.llvmModule); 1418 } 1419