1 //===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR LLVM dialect and LLVM IR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Support/LLVM.h" 17 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 18 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InlineAsm.h" 21 #include "llvm/IR/MDBuilder.h" 22 #include "llvm/IR/MatrixBuilder.h" 23 #include "llvm/IR/Operator.h" 24 25 using namespace mlir; 26 using namespace mlir::LLVM; 27 using mlir::LLVM::detail::getLLVMConstant; 28 29 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" 30 31 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { 32 using llvmFMF = llvm::FastMathFlags; 33 using FuncT = void (llvmFMF::*)(bool); 34 const std::pair<FastmathFlags, FuncT> handlers[] = { 35 // clang-format off 36 {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, 37 {FastmathFlags::ninf, &llvmFMF::setNoInfs}, 38 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, 39 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, 40 {FastmathFlags::contract, &llvmFMF::setAllowContract}, 41 {FastmathFlags::afn, &llvmFMF::setApproxFunc}, 42 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, 43 // clang-format on 44 }; 45 llvm::FastMathFlags ret; 46 ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue(); 47 for (auto it : handlers) 48 if (bitEnumContainsAll(fmfMlir, it.first)) 49 (ret.*(it.second))(true); 50 return ret; 51 } 52 53 /// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices. 54 static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) { 55 SmallVector<unsigned> position; 56 llvm::append_range(position, indices); 57 return position; 58 } 59 60 /// Convert an LLVM type to a string for printing in diagnostics. 61 static std::string diagStr(const llvm::Type *type) { 62 std::string str; 63 llvm::raw_string_ostream os(str); 64 type->print(os); 65 return str; 66 } 67 68 /// Get the declaration of an overloaded llvm intrinsic. First we get the 69 /// overloaded argument types and/or result type from the CallIntrinsicOp, and 70 /// then use those to get the correct declaration of the overloaded intrinsic. 71 static FailureOr<llvm::Function *> 72 getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, 73 llvm::Module *module, 74 LLVM::ModuleTranslation &moduleTranslation) { 75 SmallVector<llvm::Type *, 8> allArgTys; 76 for (Type type : op->getOperandTypes()) 77 allArgTys.push_back(moduleTranslation.convertType(type)); 78 79 llvm::Type *resTy; 80 if (op.getNumResults() == 0) 81 resTy = llvm::Type::getVoidTy(module->getContext()); 82 else 83 resTy = moduleTranslation.convertType(op.getResult(0).getType()); 84 85 // ATM we do not support variadic intrinsics. 86 llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false); 87 88 SmallVector<llvm::Intrinsic::IITDescriptor, 8> table; 89 getIntrinsicInfoTableEntries(id, table); 90 ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table; 91 92 SmallVector<llvm::Type *, 8> overloadedArgTys; 93 if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef, 94 overloadedArgTys) != 95 llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) { 96 return mlir::emitError(op.getLoc(), "call intrinsic signature ") 97 << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr() 98 << " does not match any of the overloads"; 99 } 100 101 ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys; 102 return llvm::Intrinsic::getOrInsertDeclaration(module, id, 103 overloadedArgTysRef); 104 } 105 106 static llvm::OperandBundleDef 107 convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, 108 LLVM::ModuleTranslation &moduleTranslation) { 109 std::vector<llvm::Value *> operands; 110 operands.reserve(bundleOperands.size()); 111 for (Value bundleArg : bundleOperands) 112 operands.push_back(moduleTranslation.lookupValue(bundleArg)); 113 return llvm::OperandBundleDef(bundleTag.str(), std::move(operands)); 114 } 115 116 static SmallVector<llvm::OperandBundleDef> 117 convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, 118 LLVM::ModuleTranslation &moduleTranslation) { 119 SmallVector<llvm::OperandBundleDef> bundles; 120 bundles.reserve(bundleOperands.size()); 121 122 for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) { 123 StringRef tag = cast<StringAttr>(tagAttr).getValue(); 124 bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation)); 125 } 126 return bundles; 127 } 128 129 static SmallVector<llvm::OperandBundleDef> 130 convertOperandBundles(OperandRangeRange bundleOperands, 131 std::optional<ArrayAttr> bundleTags, 132 LLVM::ModuleTranslation &moduleTranslation) { 133 if (!bundleTags) 134 return {}; 135 return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); 136 } 137 138 /// Builder for LLVM_CallIntrinsicOp 139 static LogicalResult 140 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, 141 LLVM::ModuleTranslation &moduleTranslation) { 142 llvm::Module *module = builder.GetInsertBlock()->getModule(); 143 llvm::Intrinsic::ID id = 144 llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr()); 145 if (!id) 146 return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ") 147 << op.getIntrinAttr(); 148 149 llvm::Function *fn = nullptr; 150 if (llvm::Intrinsic::isOverloaded(id)) { 151 auto fnOrFailure = 152 getOverloadedDeclaration(op, id, module, moduleTranslation); 153 if (failed(fnOrFailure)) 154 return failure(); 155 fn = *fnOrFailure; 156 } else { 157 fn = llvm::Intrinsic::getOrInsertDeclaration(module, id, {}); 158 } 159 160 // Check the result type of the call. 161 const llvm::Type *intrinType = 162 op.getNumResults() == 0 163 ? llvm::Type::getVoidTy(module->getContext()) 164 : moduleTranslation.convertType(op.getResultTypes().front()); 165 if (intrinType != fn->getReturnType()) { 166 return mlir::emitError(op.getLoc(), "intrinsic call returns ") 167 << diagStr(intrinType) << " but " << op.getIntrinAttr() 168 << " actually returns " << diagStr(fn->getReturnType()); 169 } 170 171 // Check the argument types of the call. If the function is variadic, check 172 // the subrange of required arguments. 173 if (!fn->getFunctionType()->isVarArg() && 174 op.getArgs().size() != fn->arg_size()) { 175 return mlir::emitError(op.getLoc(), "intrinsic call has ") 176 << op.getArgs().size() << " operands but " << op.getIntrinAttr() 177 << " expects " << fn->arg_size(); 178 } 179 if (fn->getFunctionType()->isVarArg() && 180 op.getArgs().size() < fn->arg_size()) { 181 return mlir::emitError(op.getLoc(), "intrinsic call has ") 182 << op.getArgs().size() << " operands but variadic " 183 << op.getIntrinAttr() << " expects at least " << fn->arg_size(); 184 } 185 // Check the arguments up to the number the function requires. 186 for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) { 187 const llvm::Type *expected = fn->getArg(i)->getType(); 188 const llvm::Type *actual = 189 moduleTranslation.convertType(op.getOperandTypes()[i]); 190 if (actual != expected) { 191 return mlir::emitError(op.getLoc(), "intrinsic call operand #") 192 << i << " has type " << diagStr(actual) << " but " 193 << op.getIntrinAttr() << " expects " << diagStr(expected); 194 } 195 } 196 197 FastmathFlagsInterface itf = op; 198 builder.setFastMathFlags(getFastmathFlags(itf)); 199 200 auto *inst = builder.CreateCall( 201 fn, moduleTranslation.lookupValues(op.getArgs()), 202 convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), 203 moduleTranslation)); 204 if (op.getNumResults() == 1) 205 moduleTranslation.mapValue(op->getResults().front()) = inst; 206 return success(); 207 } 208 209 static void convertLinkerOptionsOp(ArrayAttr options, 210 llvm::IRBuilderBase &builder, 211 LLVM::ModuleTranslation &moduleTranslation) { 212 llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); 213 llvm::LLVMContext &context = llvmModule->getContext(); 214 llvm::NamedMDNode *linkerMDNode = 215 llvmModule->getOrInsertNamedMetadata("llvm.linker.options"); 216 SmallVector<llvm::Metadata *> MDNodes; 217 MDNodes.reserve(options.size()); 218 for (auto s : options.getAsRange<StringAttr>()) { 219 auto *MDNode = llvm::MDString::get(context, s.getValue()); 220 MDNodes.push_back(MDNode); 221 } 222 223 auto *listMDNode = llvm::MDTuple::get(context, MDNodes); 224 linkerMDNode->addOperand(listMDNode); 225 } 226 227 static LogicalResult 228 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, 229 LLVM::ModuleTranslation &moduleTranslation) { 230 231 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); 232 if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst)) 233 builder.setFastMathFlags(getFastmathFlags(fmf)); 234 235 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" 236 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc" 237 238 // Emit function calls. If the "callee" attribute is present, this is a 239 // direct function call and we also need to look up the remapped function 240 // itself. Otherwise, this is an indirect call and the callee is the first 241 // operand, look it up as a normal value. 242 if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) { 243 auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands()); 244 SmallVector<llvm::OperandBundleDef> opBundles = 245 convertOperandBundles(callOp.getOpBundleOperands(), 246 callOp.getOpBundleTags(), moduleTranslation); 247 ArrayRef<llvm::Value *> operandsRef(operands); 248 llvm::CallInst *call; 249 if (auto attr = callOp.getCalleeAttr()) { 250 call = 251 builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()), 252 operandsRef, opBundles); 253 } else { 254 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>( 255 moduleTranslation.convertType(callOp.getCalleeFunctionType())); 256 call = builder.CreateCall(calleeType, operandsRef.front(), 257 operandsRef.drop_front(), opBundles); 258 } 259 call->setCallingConv(convertCConvToLLVM(callOp.getCConv())); 260 call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind())); 261 if (callOp.getConvergentAttr()) 262 call->addFnAttr(llvm::Attribute::Convergent); 263 if (callOp.getNoUnwindAttr()) 264 call->addFnAttr(llvm::Attribute::NoUnwind); 265 if (callOp.getWillReturnAttr()) 266 call->addFnAttr(llvm::Attribute::WillReturn); 267 268 if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { 269 llvm::MemoryEffects memEffects = 270 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, 271 convertModRefInfoToLLVM(memAttr.getArgMem())) | 272 llvm::MemoryEffects( 273 llvm::MemoryEffects::Location::InaccessibleMem, 274 convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) | 275 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, 276 convertModRefInfoToLLVM(memAttr.getOther())); 277 call->setMemoryEffects(memEffects); 278 } 279 280 moduleTranslation.setAccessGroupsMetadata(callOp, call); 281 moduleTranslation.setAliasScopeMetadata(callOp, call); 282 moduleTranslation.setTBAAMetadata(callOp, call); 283 // If the called function has a result, remap the corresponding value. Note 284 // that LLVM IR dialect CallOp has either 0 or 1 result. 285 if (opInst.getNumResults() != 0) 286 moduleTranslation.mapValue(opInst.getResult(0), call); 287 // Check that LLVM call returns void for 0-result functions. 288 else if (!call->getType()->isVoidTy()) 289 return failure(); 290 moduleTranslation.mapCall(callOp, call); 291 return success(); 292 } 293 294 if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) { 295 // TODO: refactor function type creation which usually occurs in std-LLVM 296 // conversion. 297 SmallVector<Type, 8> operandTypes; 298 llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes()); 299 300 Type resultType; 301 if (inlineAsmOp.getNumResults() == 0) { 302 resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext()); 303 } else { 304 assert(inlineAsmOp.getNumResults() == 1); 305 resultType = inlineAsmOp.getResultTypes()[0]; 306 } 307 auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); 308 llvm::InlineAsm *inlineAsmInst = 309 inlineAsmOp.getAsmDialect() 310 ? llvm::InlineAsm::get( 311 static_cast<llvm::FunctionType *>( 312 moduleTranslation.convertType(ft)), 313 inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(), 314 inlineAsmOp.getHasSideEffects(), 315 inlineAsmOp.getIsAlignStack(), 316 convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect())) 317 : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>( 318 moduleTranslation.convertType(ft)), 319 inlineAsmOp.getAsmString(), 320 inlineAsmOp.getConstraints(), 321 inlineAsmOp.getHasSideEffects(), 322 inlineAsmOp.getIsAlignStack()); 323 llvm::CallInst *inst = builder.CreateCall( 324 inlineAsmInst, 325 moduleTranslation.lookupValues(inlineAsmOp.getOperands())); 326 if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) { 327 llvm::AttributeList attrList; 328 for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) { 329 Attribute attr = it.value(); 330 if (!attr) 331 continue; 332 DictionaryAttr dAttr = cast<DictionaryAttr>(attr); 333 TypeAttr tAttr = 334 cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName())); 335 llvm::AttrBuilder b(moduleTranslation.getLLVMContext()); 336 llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue()); 337 b.addTypeAttr(llvm::Attribute::ElementType, ty); 338 // shift to account for the returned value (this is always 1 aggregate 339 // value in LLVM). 340 int shift = (opInst.getNumResults() > 0) ? 1 : 0; 341 attrList = attrList.addAttributesAtIndex( 342 moduleTranslation.getLLVMContext(), it.index() + shift, b); 343 } 344 inst->setAttributes(attrList); 345 } 346 347 if (opInst.getNumResults() != 0) 348 moduleTranslation.mapValue(opInst.getResult(0), inst); 349 return success(); 350 } 351 352 if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) { 353 auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands()); 354 SmallVector<llvm::OperandBundleDef> opBundles = 355 convertOperandBundles(invOp.getOpBundleOperands(), 356 invOp.getOpBundleTags(), moduleTranslation); 357 ArrayRef<llvm::Value *> operandsRef(operands); 358 llvm::InvokeInst *result; 359 if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) { 360 result = builder.CreateInvoke( 361 moduleTranslation.lookupFunction(attr.getValue()), 362 moduleTranslation.lookupBlock(invOp.getSuccessor(0)), 363 moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef, 364 opBundles); 365 } else { 366 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>( 367 moduleTranslation.convertType(invOp.getCalleeFunctionType())); 368 result = builder.CreateInvoke( 369 calleeType, operandsRef.front(), 370 moduleTranslation.lookupBlock(invOp.getSuccessor(0)), 371 moduleTranslation.lookupBlock(invOp.getSuccessor(1)), 372 operandsRef.drop_front(), opBundles); 373 } 374 result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); 375 moduleTranslation.mapBranch(invOp, result); 376 // InvokeOp can only have 0 or 1 result 377 if (invOp->getNumResults() != 0) { 378 moduleTranslation.mapValue(opInst.getResult(0), result); 379 return success(); 380 } 381 return success(result->getType()->isVoidTy()); 382 } 383 384 if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) { 385 llvm::Type *ty = moduleTranslation.convertType(lpOp.getType()); 386 llvm::LandingPadInst *lpi = 387 builder.CreateLandingPad(ty, lpOp.getNumOperands()); 388 lpi->setCleanup(lpOp.getCleanup()); 389 390 // Add clauses 391 for (llvm::Value *operand : 392 moduleTranslation.lookupValues(lpOp.getOperands())) { 393 // All operands should be constant - checked by verifier 394 if (auto *constOperand = dyn_cast<llvm::Constant>(operand)) 395 lpi->addClause(constOperand); 396 } 397 moduleTranslation.mapValue(lpOp.getResult(), lpi); 398 return success(); 399 } 400 401 // Emit branches. We need to look up the remapped blocks and ignore the 402 // block arguments that were transformed into PHI nodes. 403 if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) { 404 llvm::BranchInst *branch = 405 builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); 406 moduleTranslation.mapBranch(&opInst, branch); 407 moduleTranslation.setLoopMetadata(&opInst, branch); 408 return success(); 409 } 410 if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) { 411 llvm::BranchInst *branch = builder.CreateCondBr( 412 moduleTranslation.lookupValue(condbrOp.getOperand(0)), 413 moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), 414 moduleTranslation.lookupBlock(condbrOp.getSuccessor(1))); 415 moduleTranslation.mapBranch(&opInst, branch); 416 moduleTranslation.setLoopMetadata(&opInst, branch); 417 return success(); 418 } 419 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) { 420 llvm::SwitchInst *switchInst = builder.CreateSwitch( 421 moduleTranslation.lookupValue(switchOp.getValue()), 422 moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), 423 switchOp.getCaseDestinations().size()); 424 425 // Handle switch with zero cases. 426 if (!switchOp.getCaseValues()) 427 return success(); 428 429 auto *ty = llvm::cast<llvm::IntegerType>( 430 moduleTranslation.convertType(switchOp.getValue().getType())); 431 for (auto i : 432 llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()), 433 switchOp.getCaseDestinations())) 434 switchInst->addCase( 435 llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), 436 moduleTranslation.lookupBlock(std::get<1>(i))); 437 438 moduleTranslation.mapBranch(&opInst, switchInst); 439 return success(); 440 } 441 442 // Emit addressof. We need to look up the global value referenced by the 443 // operation and store it in the MLIR-to-LLVM value mapping. This does not 444 // emit any LLVM instruction. 445 if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { 446 LLVM::GlobalOp global = 447 addressOfOp.getGlobal(moduleTranslation.symbolTable()); 448 LLVM::LLVMFuncOp function = 449 addressOfOp.getFunction(moduleTranslation.symbolTable()); 450 451 // The verifier should not have allowed this. 452 assert((global || function) && 453 "referencing an undefined global or function"); 454 455 moduleTranslation.mapValue( 456 addressOfOp.getResult(), 457 global ? moduleTranslation.lookupGlobal(global) 458 : moduleTranslation.lookupFunction(function.getName())); 459 return success(); 460 } 461 462 return failure(); 463 } 464 465 namespace { 466 /// Implementation of the dialect interface that converts operations belonging 467 /// to the LLVM dialect to LLVM IR. 468 class LLVMDialectLLVMIRTranslationInterface 469 : public LLVMTranslationDialectInterface { 470 public: 471 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 472 473 /// Translates the given operation to LLVM IR using the provided IR builder 474 /// and saving the state in `moduleTranslation`. 475 LogicalResult 476 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 477 LLVM::ModuleTranslation &moduleTranslation) const final { 478 return convertOperationImpl(*op, builder, moduleTranslation); 479 } 480 }; 481 } // namespace 482 483 void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) { 484 registry.insert<LLVM::LLVMDialect>(); 485 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { 486 dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>(); 487 }); 488 } 489 490 void mlir::registerLLVMDialectTranslation(MLIRContext &context) { 491 DialectRegistry registry; 492 registerLLVMDialectTranslation(registry); 493 context.appendDialectRegistry(registry); 494 } 495