1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR Func and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 15 16 #include "mlir/Analysis/DataLayoutAnalysis.h" 17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 20 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 21 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 22 #include "mlir/Conversion/LLVMCommon/Pattern.h" 23 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 24 #include "mlir/Dialect/Func/IR/FuncOps.h" 25 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 27 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 28 #include "mlir/Dialect/Utils/StaticValueUtils.h" 29 #include "mlir/IR/Attributes.h" 30 #include "mlir/IR/Builders.h" 31 #include "mlir/IR/BuiltinAttributeInterfaces.h" 32 #include "mlir/IR/BuiltinAttributes.h" 33 #include "mlir/IR/BuiltinOps.h" 34 #include "mlir/IR/IRMapping.h" 35 #include "mlir/IR/PatternMatch.h" 36 #include "mlir/IR/SymbolTable.h" 37 #include "mlir/IR/TypeUtilities.h" 38 #include "mlir/Transforms/DialectConversion.h" 39 #include "mlir/Transforms/Passes.h" 40 #include "llvm/ADT/SmallVector.h" 41 #include "llvm/ADT/TypeSwitch.h" 42 #include "llvm/IR/DerivedTypes.h" 43 #include "llvm/IR/IRBuilder.h" 44 #include "llvm/IR/Type.h" 45 #include "llvm/Support/Casting.h" 46 #include "llvm/Support/CommandLine.h" 47 #include "llvm/Support/FormatVariadic.h" 48 #include <algorithm> 49 #include <functional> 50 #include <optional> 51 52 namespace mlir { 53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS 54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS 55 #include "mlir/Conversion/Passes.h.inc" 56 } // namespace mlir 57 58 using namespace mlir; 59 60 #define PASS_NAME "convert-func-to-llvm" 61 62 static constexpr StringRef varargsAttrName = "func.varargs"; 63 static constexpr StringRef linkageAttrName = "llvm.linkage"; 64 static constexpr StringRef barePtrAttrName = "llvm.bareptr"; 65 66 /// Return `true` if the `op` should use bare pointer calling convention. 67 static bool shouldUseBarePtrCallConv(Operation *op, 68 const LLVMTypeConverter *typeConverter) { 69 return (op && op->hasAttr(barePtrAttrName)) || 70 typeConverter->getOptions().useBarePtrCallConv; 71 } 72 73 /// Only retain those attributes that are not constructed by 74 /// `LLVMFuncOp::build`. 75 static void filterFuncAttributes(FunctionOpInterface func, 76 SmallVectorImpl<NamedAttribute> &result) { 77 for (const NamedAttribute &attr : func->getDiscardableAttrs()) { 78 if (attr.getName() == linkageAttrName || 79 attr.getName() == varargsAttrName || 80 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName()) 81 continue; 82 result.push_back(attr); 83 } 84 } 85 86 /// Propagate argument/results attributes. 87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, 88 FunctionOpInterface funcOp, 89 LLVM::LLVMFuncOp wrapperFuncOp) { 90 auto argAttrs = funcOp.getAllArgAttrs(); 91 if (!resultStructType) { 92 if (auto resAttrs = funcOp.getAllResultAttrs()) 93 wrapperFuncOp.setAllResultAttrs(resAttrs); 94 if (argAttrs) 95 wrapperFuncOp.setAllArgAttrs(argAttrs); 96 } else { 97 SmallVector<Attribute> argAttributes; 98 // Only modify the argument and result attributes when the result is now 99 // an argument. 100 if (argAttrs) { 101 argAttributes.push_back(builder.getDictionaryAttr({})); 102 argAttributes.append(argAttrs.begin(), argAttrs.end()); 103 wrapperFuncOp.setAllArgAttrs(argAttributes); 104 } 105 } 106 cast<FunctionOpInterface>(wrapperFuncOp.getOperation()) 107 .setVisibility(funcOp.getVisibility()); 108 } 109 110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 111 /// arguments instead of unpacked arguments. This function can be called from C 112 /// by passing a pointer to a C struct corresponding to a memref descriptor. 113 /// Similarly, returned memrefs are passed via pointers to a C struct that is 114 /// passed as additional argument. 115 /// Internally, the auxiliary function unpacks the descriptor into individual 116 /// components and forwards them to `newFuncOp` and forwards the results to 117 /// the extra arguments. 118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, 119 const LLVMTypeConverter &typeConverter, 120 FunctionOpInterface funcOp, 121 LLVM::LLVMFuncOp newFuncOp) { 122 auto type = cast<FunctionType>(funcOp.getFunctionType()); 123 auto [wrapperFuncType, resultStructType] = 124 typeConverter.convertFunctionTypeCWrapper(type); 125 126 SmallVector<NamedAttribute> attributes; 127 filterFuncAttributes(funcOp, attributes); 128 129 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 130 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 131 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false, 132 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); 133 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); 134 135 OpBuilder::InsertionGuard guard(rewriter); 136 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter)); 137 138 SmallVector<Value, 8> args; 139 size_t argOffset = resultStructType ? 1 : 0; 140 for (auto [index, argType] : llvm::enumerate(type.getInputs())) { 141 Value arg = wrapperFuncOp.getArgument(index + argOffset); 142 if (auto memrefType = dyn_cast<MemRefType>(argType)) { 143 Value loaded = rewriter.create<LLVM::LoadOp>( 144 loc, typeConverter.convertType(memrefType), arg); 145 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); 146 continue; 147 } 148 if (isa<UnrankedMemRefType>(argType)) { 149 Value loaded = rewriter.create<LLVM::LoadOp>( 150 loc, typeConverter.convertType(argType), arg); 151 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); 152 continue; 153 } 154 155 args.push_back(arg); 156 } 157 158 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); 159 160 if (resultStructType) { 161 rewriter.create<LLVM::StoreOp>(loc, call.getResult(), 162 wrapperFuncOp.getArgument(0)); 163 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); 164 } else { 165 rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); 166 } 167 } 168 169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 170 /// arguments instead of unpacked arguments. Creates a body for the (external) 171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the 172 /// individual arguments into this descriptor and passes a pointer to it into 173 /// the auxiliary function. If the result of the function cannot be directly 174 /// returned, we write it to a special first argument that provides a pointer 175 /// to a corresponding struct. This auxiliary external function is now 176 /// compatible with functions defined in C using pointers to C structs 177 /// corresponding to a memref descriptor. 178 static void wrapExternalFunction(OpBuilder &builder, Location loc, 179 const LLVMTypeConverter &typeConverter, 180 FunctionOpInterface funcOp, 181 LLVM::LLVMFuncOp newFuncOp) { 182 OpBuilder::InsertionGuard guard(builder); 183 184 auto [wrapperType, resultStructType] = 185 typeConverter.convertFunctionTypeCWrapper( 186 cast<FunctionType>(funcOp.getFunctionType())); 187 // This conversion can only fail if it could not convert one of the argument 188 // types. But since it has been applied to a non-wrapper function before, it 189 // should have failed earlier and not reach this point at all. 190 assert(wrapperType && "unexpected type conversion failure"); 191 192 SmallVector<NamedAttribute, 4> attributes; 193 filterFuncAttributes(funcOp, attributes); 194 195 // Create the auxiliary function. 196 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>( 197 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 198 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false, 199 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); 200 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc); 201 202 // The wrapper that we synthetize here should only be visible in this module. 203 newFuncOp.setLinkage(LLVM::Linkage::Private); 204 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder)); 205 206 // Get a ValueRange containing arguments. 207 FunctionType type = cast<FunctionType>(funcOp.getFunctionType()); 208 SmallVector<Value, 8> args; 209 args.reserve(type.getNumInputs()); 210 ValueRange wrapperArgsRange(newFuncOp.getArguments()); 211 212 if (resultStructType) { 213 // Allocate the struct on the stack and pass the pointer. 214 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0); 215 Value one = builder.create<LLVM::ConstantOp>( 216 loc, typeConverter.convertType(builder.getIndexType()), 217 builder.getIntegerAttr(builder.getIndexType(), 1)); 218 Value result = 219 builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one); 220 args.push_back(result); 221 } 222 223 // Iterate over the inputs of the original function and pack values into 224 // memref descriptors if the original type is a memref. 225 for (Type input : type.getInputs()) { 226 Value arg; 227 int numToDrop = 1; 228 auto memRefType = dyn_cast<MemRefType>(input); 229 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input); 230 if (memRefType || unrankedMemRefType) { 231 numToDrop = memRefType 232 ? MemRefDescriptor::getNumUnpackedValues(memRefType) 233 : UnrankedMemRefDescriptor::getNumUnpackedValues(); 234 Value packed = 235 memRefType 236 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, 237 wrapperArgsRange.take_front(numToDrop)) 238 : UnrankedMemRefDescriptor::pack( 239 builder, loc, typeConverter, unrankedMemRefType, 240 wrapperArgsRange.take_front(numToDrop)); 241 242 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); 243 Value one = builder.create<LLVM::ConstantOp>( 244 loc, typeConverter.convertType(builder.getIndexType()), 245 builder.getIntegerAttr(builder.getIndexType(), 1)); 246 Value allocated = builder.create<LLVM::AllocaOp>( 247 loc, ptrTy, packed.getType(), one, /*alignment=*/0); 248 builder.create<LLVM::StoreOp>(loc, packed, allocated); 249 arg = allocated; 250 } else { 251 arg = wrapperArgsRange[0]; 252 } 253 254 args.push_back(arg); 255 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); 256 } 257 assert(wrapperArgsRange.empty() && "did not map some of the arguments"); 258 259 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args); 260 261 if (resultStructType) { 262 Value result = 263 builder.create<LLVM::LoadOp>(loc, resultStructType, args.front()); 264 builder.create<LLVM::ReturnOp>(loc, result); 265 } else { 266 builder.create<LLVM::ReturnOp>(loc, call.getResults()); 267 } 268 } 269 270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee 271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted 272 /// to LLVM pointer types. 273 static void restoreByValRefArgumentType( 274 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, 275 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs, 276 ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) { 277 // Nothing to do for function declarations. 278 if (funcOp.isExternal()) 279 return; 280 281 ConversionPatternRewriter::InsertionGuard guard(rewriter); 282 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); 283 284 for (const auto &[arg, oldArg, byValRefAttr] : 285 llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) { 286 // Skip argument if no `llvm.byval` or `llvm.byref` attribute. 287 if (!byValRefAttr) 288 continue; 289 290 // Insert load to retrieve the actual argument passed by value/reference. 291 assert(isa<LLVM::LLVMPointerType>(arg.getType()) && 292 "Expected LLVM pointer type for argument with " 293 "`llvm.byval`/`llvm.byref` attribute"); 294 Type resTy = typeConverter.convertType( 295 cast<TypeAttr>(byValRefAttr->getValue()).getValue()); 296 297 auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg); 298 rewriter.replaceUsesOfBlockArgument(oldArg, valueArg); 299 } 300 } 301 302 FailureOr<LLVM::LLVMFuncOp> 303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, 304 ConversionPatternRewriter &rewriter, 305 const LLVMTypeConverter &converter) { 306 // Check the funcOp has `FunctionType`. 307 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType()); 308 if (!funcTy) 309 return rewriter.notifyMatchFailure( 310 funcOp, "Only support FunctionOpInterface with FunctionType"); 311 312 // Keep track of the entry block arguments. They will be needed later. 313 SmallVector<BlockArgument> oldBlockArgs = 314 llvm::to_vector(funcOp.getArguments()); 315 316 // Convert the original function arguments. They are converted using the 317 // LLVMTypeConverter provided to this legalization pattern. 318 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName); 319 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was 320 // overriden with an LLVM pointer type for later processing. 321 SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs; 322 TypeConverter::SignatureConversion result(funcOp.getNumArguments()); 323 auto llvmType = converter.convertFunctionSignature( 324 funcOp, varargsAttr && varargsAttr.getValue(), 325 shouldUseBarePtrCallConv(funcOp, &converter), result, 326 byValRefNonPtrAttrs); 327 if (!llvmType) 328 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed"); 329 330 // Create an LLVM function, use external linkage by default until MLIR 331 // functions have linkage. 332 LLVM::Linkage linkage = LLVM::Linkage::External; 333 if (funcOp->hasAttr(linkageAttrName)) { 334 auto attr = 335 dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName)); 336 if (!attr) { 337 funcOp->emitError() << "Contains " << linkageAttrName 338 << " attribute not of type LLVM::LinkageAttr"; 339 return rewriter.notifyMatchFailure( 340 funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr"); 341 } 342 linkage = attr.getLinkage(); 343 } 344 345 SmallVector<NamedAttribute, 4> attributes; 346 filterFuncAttributes(funcOp, attributes); 347 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 348 funcOp.getLoc(), funcOp.getName(), llvmType, linkage, 349 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, 350 attributes); 351 cast<FunctionOpInterface>(newFuncOp.getOperation()) 352 .setVisibility(funcOp.getVisibility()); 353 354 // Create a memory effect attribute corresponding to readnone. 355 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName(); 356 if (funcOp->hasAttr(readnoneAttrName)) { 357 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName); 358 if (!attr) { 359 funcOp->emitError() << "Contains " << readnoneAttrName 360 << " attribute not of type UnitAttr"; 361 return rewriter.notifyMatchFailure( 362 funcOp, "Contains readnone attribute not of type UnitAttr"); 363 } 364 auto memoryAttr = LLVM::MemoryEffectsAttr::get( 365 rewriter.getContext(), 366 {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef, 367 LLVM::ModRefInfo::NoModRef}); 368 newFuncOp.setMemoryEffectsAttr(memoryAttr); 369 } 370 371 // Propagate argument/result attributes to all converted arguments/result 372 // obtained after converting a given original argument/result. 373 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { 374 assert(!resAttrDicts.empty() && "expected array to be non-empty"); 375 if (funcOp.getNumResults() == 1) 376 newFuncOp.setAllResultAttrs(resAttrDicts); 377 } 378 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { 379 SmallVector<Attribute> newArgAttrs( 380 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams()); 381 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { 382 // Some LLVM IR attribute have a type attached to them. During FuncOp -> 383 // LLVMFuncOp conversion these types may have changed. Account for that 384 // change by converting attributes' types as well. 385 SmallVector<NamedAttribute, 4> convertedAttrs; 386 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]); 387 convertedAttrs.reserve(attrsDict.size()); 388 for (const NamedAttribute &attr : attrsDict) { 389 const auto convert = [&](const NamedAttribute &attr) { 390 return TypeAttr::get(converter.convertType( 391 cast<TypeAttr>(attr.getValue()).getValue())); 392 }; 393 if (attr.getName().getValue() == 394 LLVM::LLVMDialect::getByValAttrName()) { 395 convertedAttrs.push_back(rewriter.getNamedAttr( 396 LLVM::LLVMDialect::getByValAttrName(), convert(attr))); 397 } else if (attr.getName().getValue() == 398 LLVM::LLVMDialect::getByRefAttrName()) { 399 convertedAttrs.push_back(rewriter.getNamedAttr( 400 LLVM::LLVMDialect::getByRefAttrName(), convert(attr))); 401 } else if (attr.getName().getValue() == 402 LLVM::LLVMDialect::getStructRetAttrName()) { 403 convertedAttrs.push_back(rewriter.getNamedAttr( 404 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr))); 405 } else if (attr.getName().getValue() == 406 LLVM::LLVMDialect::getInAllocaAttrName()) { 407 convertedAttrs.push_back(rewriter.getNamedAttr( 408 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr))); 409 } else { 410 convertedAttrs.push_back(attr); 411 } 412 } 413 auto mapping = result.getInputMapping(i); 414 assert(mapping && "unexpected deletion of function argument"); 415 // Only attach the new argument attributes if there is a one-to-one 416 // mapping from old to new types. Otherwise, attributes might be 417 // attached to types that they do not support. 418 if (mapping->size == 1) { 419 newArgAttrs[mapping->inputNo] = 420 DictionaryAttr::get(rewriter.getContext(), convertedAttrs); 421 continue; 422 } 423 // TODO: Implement custom handling for types that expand to multiple 424 // function arguments. 425 for (size_t j = 0; j < mapping->size; ++j) 426 newArgAttrs[mapping->inputNo + j] = 427 DictionaryAttr::get(rewriter.getContext(), {}); 428 } 429 if (!newArgAttrs.empty()) 430 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs)); 431 } 432 433 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(), 434 newFuncOp.end()); 435 // Convert just the entry block. The remaining unstructured control flow is 436 // converted by ControlFlowToLLVM. 437 if (!newFuncOp.getBody().empty()) 438 rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result, 439 &converter); 440 441 // Fix the type mismatch between the materialized `llvm.ptr` and the expected 442 // pointee type in the function body when converting `llvm.byval`/`llvm.byref` 443 // function arguments. 444 restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs, 445 oldBlockArgs, newFuncOp); 446 447 if (!shouldUseBarePtrCallConv(funcOp, &converter)) { 448 if (funcOp->getAttrOfType<UnitAttr>( 449 LLVM::LLVMDialect::getEmitCWrapperAttrName())) { 450 if (newFuncOp.isVarArg()) 451 return funcOp.emitError("C interface for variadic functions is not " 452 "supported yet."); 453 454 if (newFuncOp.isExternal()) 455 wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp, 456 newFuncOp); 457 else 458 wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp, 459 newFuncOp); 460 } 461 } 462 463 return newFuncOp; 464 } 465 466 namespace { 467 468 /// FuncOp legalization pattern that converts MemRef arguments to pointers to 469 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type 470 /// information. 471 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> { 472 FuncOpConversion(const LLVMTypeConverter &converter) 473 : ConvertOpToLLVMPattern(converter) {} 474 475 LogicalResult 476 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 477 ConversionPatternRewriter &rewriter) const override { 478 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp( 479 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter, 480 *getTypeConverter()); 481 if (failed(newFuncOp)) 482 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop"); 483 484 rewriter.eraseOp(funcOp); 485 return success(); 486 } 487 }; 488 489 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { 490 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern; 491 492 LogicalResult 493 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor, 494 ConversionPatternRewriter &rewriter) const override { 495 auto type = typeConverter->convertType(op.getResult().getType()); 496 if (!type || !LLVM::isCompatibleType(type)) 497 return rewriter.notifyMatchFailure(op, "failed to convert result type"); 498 499 auto newOp = 500 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); 501 for (const NamedAttribute &attr : op->getAttrs()) { 502 if (attr.getName().strref() == "value") 503 continue; 504 newOp->setAttr(attr.getName(), attr.getValue()); 505 } 506 rewriter.replaceOp(op, newOp->getResults()); 507 return success(); 508 } 509 }; 510 511 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and 512 // passes the pointer to the MemRef across function boundaries. 513 template <typename CallOpType> 514 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { 515 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; 516 using Super = CallOpInterfaceLowering<CallOpType>; 517 using Base = ConvertOpToLLVMPattern<CallOpType>; 518 519 LogicalResult matchAndRewriteImpl(CallOpType callOp, 520 typename CallOpType::Adaptor adaptor, 521 ConversionPatternRewriter &rewriter, 522 bool useBarePtrCallConv = false) const { 523 // Pack the result types into a struct. 524 Type packedResult = nullptr; 525 unsigned numResults = callOp.getNumResults(); 526 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); 527 528 if (numResults != 0) { 529 if (!(packedResult = this->getTypeConverter()->packFunctionResults( 530 resultTypes, useBarePtrCallConv))) 531 return failure(); 532 } 533 534 if (useBarePtrCallConv) { 535 for (auto it : callOp->getOperands()) { 536 Type operandType = it.getType(); 537 if (isa<UnrankedMemRefType>(operandType)) { 538 // Unranked memref is not supported in the bare pointer calling 539 // convention. 540 return failure(); 541 } 542 } 543 } 544 auto promoted = this->getTypeConverter()->promoteOperands( 545 callOp.getLoc(), /*opOperands=*/callOp->getOperands(), 546 adaptor.getOperands(), rewriter, useBarePtrCallConv); 547 auto newOp = rewriter.create<LLVM::CallOp>( 548 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), 549 promoted, callOp->getAttrs()); 550 551 newOp.getProperties().operandSegmentSizes = { 552 static_cast<int32_t>(promoted.size()), 0}; 553 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); 554 555 SmallVector<Value, 4> results; 556 if (numResults < 2) { 557 // If < 2 results, packing did not do anything and we can just return. 558 results.append(newOp.result_begin(), newOp.result_end()); 559 } else { 560 // Otherwise, it had been converted to an operation producing a structure. 561 // Extract individual results from the structure and return them as list. 562 results.reserve(numResults); 563 for (unsigned i = 0; i < numResults; ++i) { 564 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 565 callOp.getLoc(), newOp->getResult(0), i)); 566 } 567 } 568 569 if (useBarePtrCallConv) { 570 // For the bare-ptr calling convention, promote memref results to 571 // descriptors. 572 assert(results.size() == resultTypes.size() && 573 "The number of arguments and types doesn't match"); 574 this->getTypeConverter()->promoteBarePtrsToDescriptors( 575 rewriter, callOp.getLoc(), resultTypes, results); 576 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), 577 resultTypes, results, 578 /*toDynamic=*/false))) { 579 return failure(); 580 } 581 582 rewriter.replaceOp(callOp, results); 583 return success(); 584 } 585 }; 586 587 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { 588 public: 589 CallOpLowering(const LLVMTypeConverter &typeConverter, 590 // Can be nullptr. 591 const SymbolTable *symbolTable, PatternBenefit benefit = 1) 592 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit), 593 symbolTable(symbolTable) {} 594 595 LogicalResult 596 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, 597 ConversionPatternRewriter &rewriter) const override { 598 bool useBarePtrCallConv = false; 599 if (getTypeConverter()->getOptions().useBarePtrCallConv) { 600 useBarePtrCallConv = true; 601 } else if (symbolTable != nullptr) { 602 // Fast lookup. 603 Operation *callee = 604 symbolTable->lookup(callOp.getCalleeAttr().getValue()); 605 useBarePtrCallConv = 606 callee != nullptr && callee->hasAttr(barePtrAttrName); 607 } else { 608 // Warning: This is a linear lookup. 609 Operation *callee = 610 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()); 611 useBarePtrCallConv = 612 callee != nullptr && callee->hasAttr(barePtrAttrName); 613 } 614 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv); 615 } 616 617 private: 618 const SymbolTable *symbolTable = nullptr; 619 }; 620 621 struct CallIndirectOpLowering 622 : public CallOpInterfaceLowering<func::CallIndirectOp> { 623 using Super::Super; 624 625 LogicalResult 626 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, 627 ConversionPatternRewriter &rewriter) const override { 628 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); 629 } 630 }; 631 632 struct UnrealizedConversionCastOpLowering 633 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> { 634 using ConvertOpToLLVMPattern< 635 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; 636 637 LogicalResult 638 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, 639 ConversionPatternRewriter &rewriter) const override { 640 SmallVector<Type> convertedTypes; 641 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(), 642 convertedTypes)) && 643 convertedTypes == adaptor.getInputs().getTypes()) { 644 rewriter.replaceOp(op, adaptor.getInputs()); 645 return success(); 646 } 647 648 convertedTypes.clear(); 649 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(), 650 convertedTypes)) && 651 convertedTypes == op.getOutputs().getType()) { 652 rewriter.replaceOp(op, adaptor.getInputs()); 653 return success(); 654 } 655 return failure(); 656 } 657 }; 658 659 // Special lowering pattern for `ReturnOps`. Unlike all other operations, 660 // `ReturnOp` interacts with the function signature and must have as many 661 // operands as the function has return values. Because in LLVM IR, functions 662 // can only return 0 or 1 value, we pack multiple values into a structure type. 663 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if 664 // necessary before returning it 665 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { 666 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; 667 668 LogicalResult 669 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 670 ConversionPatternRewriter &rewriter) const override { 671 Location loc = op.getLoc(); 672 unsigned numArguments = op.getNumOperands(); 673 SmallVector<Value, 4> updatedOperands; 674 675 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); 676 bool useBarePtrCallConv = 677 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); 678 if (useBarePtrCallConv) { 679 // For the bare-ptr calling convention, extract the aligned pointer to 680 // be returned from the memref descriptor. 681 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { 682 Type oldTy = std::get<0>(it).getType(); 683 Value newOperand = std::get<1>(it); 684 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( 685 cast<BaseMemRefType>(oldTy))) { 686 MemRefDescriptor memrefDesc(newOperand); 687 newOperand = memrefDesc.allocatedPtr(rewriter, loc); 688 } else if (isa<UnrankedMemRefType>(oldTy)) { 689 // Unranked memref is not supported in the bare pointer calling 690 // convention. 691 return failure(); 692 } 693 updatedOperands.push_back(newOperand); 694 } 695 } else { 696 updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); 697 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), 698 updatedOperands, 699 /*toDynamic=*/true); 700 } 701 702 // If ReturnOp has 0 or 1 operand, create it and return immediately. 703 if (numArguments <= 1) { 704 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 705 op, TypeRange(), updatedOperands, op->getAttrs()); 706 return success(); 707 } 708 709 // Otherwise, we need to pack the arguments into an LLVM struct type before 710 // returning. 711 auto packedType = getTypeConverter()->packFunctionResults( 712 op.getOperandTypes(), useBarePtrCallConv); 713 if (!packedType) { 714 return rewriter.notifyMatchFailure(op, "could not convert result types"); 715 } 716 717 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); 718 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { 719 packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); 720 } 721 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, 722 op->getAttrs()); 723 return success(); 724 } 725 }; 726 } // namespace 727 728 void mlir::populateFuncToLLVMFuncOpConversionPattern( 729 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 730 patterns.add<FuncOpConversion>(converter); 731 } 732 733 void mlir::populateFuncToLLVMConversionPatterns( 734 const LLVMTypeConverter &converter, RewritePatternSet &patterns, 735 const SymbolTable *symbolTable) { 736 populateFuncToLLVMFuncOpConversionPattern(converter, patterns); 737 patterns.add<CallIndirectOpLowering>(converter); 738 patterns.add<CallOpLowering>(converter, symbolTable); 739 patterns.add<ConstantOpLowering>(converter); 740 patterns.add<ReturnOpLowering>(converter); 741 } 742 743 namespace { 744 /// A pass converting Func operations into the LLVM IR dialect. 745 struct ConvertFuncToLLVMPass 746 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> { 747 using Base::Base; 748 749 /// Run the dialect converter on the module. 750 void runOnOperation() override { 751 ModuleOp m = getOperation(); 752 StringRef dataLayout; 753 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>( 754 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())); 755 if (dataLayoutAttr) 756 dataLayout = dataLayoutAttr.getValue(); 757 758 if (failed(LLVM::LLVMDialect::verifyDataLayoutString( 759 dataLayout, [this](const Twine &message) { 760 getOperation().emitError() << message.str(); 761 }))) { 762 signalPassFailure(); 763 return; 764 } 765 766 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 767 768 LowerToLLVMOptions options(&getContext(), 769 dataLayoutAnalysis.getAtOrAbove(m)); 770 options.useBarePtrCallConv = useBarePtrCallConv; 771 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 772 options.overrideIndexBitwidth(indexBitwidth); 773 options.dataLayout = llvm::DataLayout(dataLayout); 774 775 LLVMTypeConverter typeConverter(&getContext(), options, 776 &dataLayoutAnalysis); 777 778 std::optional<SymbolTable> optSymbolTable = std::nullopt; 779 const SymbolTable *symbolTable = nullptr; 780 if (!options.useBarePtrCallConv) { 781 optSymbolTable.emplace(m); 782 symbolTable = &optSymbolTable.value(); 783 } 784 785 RewritePatternSet patterns(&getContext()); 786 populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable); 787 788 LLVMConversionTarget target(getContext()); 789 if (failed(applyPartialConversion(m, target, std::move(patterns)))) 790 signalPassFailure(); 791 } 792 }; 793 794 struct SetLLVMModuleDataLayoutPass 795 : public impl::SetLLVMModuleDataLayoutPassBase< 796 SetLLVMModuleDataLayoutPass> { 797 using Base::Base; 798 799 /// Run the dialect converter on the module. 800 void runOnOperation() override { 801 if (failed(LLVM::LLVMDialect::verifyDataLayoutString( 802 this->dataLayout, [this](const Twine &message) { 803 getOperation().emitError() << message.str(); 804 }))) { 805 signalPassFailure(); 806 return; 807 } 808 ModuleOp m = getOperation(); 809 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), 810 StringAttr::get(m.getContext(), this->dataLayout)); 811 } 812 }; 813 } // namespace 814 815 //===----------------------------------------------------------------------===// 816 // ConvertToLLVMPatternInterface implementation 817 //===----------------------------------------------------------------------===// 818 819 namespace { 820 /// Implement the interface to convert Func to LLVM. 821 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 822 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 823 /// Hook for derived dialect interface to provide conversion patterns 824 /// and mark dialect legal for the conversion target. 825 void populateConvertToLLVMConversionPatterns( 826 ConversionTarget &target, LLVMTypeConverter &typeConverter, 827 RewritePatternSet &patterns) const final { 828 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 829 } 830 }; 831 } // namespace 832 833 void mlir::registerConvertFuncToLLVMInterface(DialectRegistry ®istry) { 834 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { 835 dialect->addInterfaces<FuncToLLVMDialectInterface>(); 836 }); 837 } 838