1 //===- TypeToLLVM.cpp - type translation from MLIR to LLVM IR -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/MLIRContext.h" 13 14 #include "llvm/ADT/TypeSwitch.h" 15 #include "llvm/IR/DataLayout.h" 16 #include "llvm/IR/DerivedTypes.h" 17 #include "llvm/IR/Type.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 namespace LLVM { 23 namespace detail { 24 /// Support for translating MLIR LLVM dialect types to LLVM IR. 25 class TypeToLLVMIRTranslatorImpl { 26 public: 27 /// Constructs a class creating types in the given LLVM context. 28 TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {} 29 30 /// Translates a single type. 31 llvm::Type *translateType(Type type) { 32 // If the conversion is already known, just return it. 33 if (knownTranslations.count(type)) 34 return knownTranslations.lookup(type); 35 36 // Dispatch to an appropriate function. 37 llvm::Type *translated = 38 llvm::TypeSwitch<Type, llvm::Type *>(type) 39 .Case([this](LLVM::LLVMVoidType) { 40 return llvm::Type::getVoidTy(context); 41 }) 42 .Case( 43 [this](Float16Type) { return llvm::Type::getHalfTy(context); }) 44 .Case([this](BFloat16Type) { 45 return llvm::Type::getBFloatTy(context); 46 }) 47 .Case( 48 [this](Float32Type) { return llvm::Type::getFloatTy(context); }) 49 .Case([this](Float64Type) { 50 return llvm::Type::getDoubleTy(context); 51 }) 52 .Case([this](Float80Type) { 53 return llvm::Type::getX86_FP80Ty(context); 54 }) 55 .Case([this](Float128Type) { 56 return llvm::Type::getFP128Ty(context); 57 }) 58 .Case([this](LLVM::LLVMPPCFP128Type) { 59 return llvm::Type::getPPC_FP128Ty(context); 60 }) 61 .Case([this](LLVM::LLVMTokenType) { 62 return llvm::Type::getTokenTy(context); 63 }) 64 .Case([this](LLVM::LLVMLabelType) { 65 return llvm::Type::getLabelTy(context); 66 }) 67 .Case([this](LLVM::LLVMMetadataType) { 68 return llvm::Type::getMetadataTy(context); 69 }) 70 .Case([this](LLVM::LLVMX86AMXType) { 71 return llvm::Type::getX86_AMXTy(context); 72 }) 73 .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, 74 LLVM::LLVMPointerType, LLVM::LLVMStructType, 75 LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType, 76 VectorType, LLVM::LLVMTargetExtType>( 77 [this](auto type) { return this->translate(type); }) 78 .Default([](Type t) -> llvm::Type * { 79 llvm_unreachable("unknown LLVM dialect type"); 80 }); 81 82 // Cache the result of the conversion and return. 83 knownTranslations.try_emplace(type, translated); 84 return translated; 85 } 86 87 private: 88 /// Translates the given array type. 89 llvm::Type *translate(LLVM::LLVMArrayType type) { 90 return llvm::ArrayType::get(translateType(type.getElementType()), 91 type.getNumElements()); 92 } 93 94 /// Translates the given function type. 95 llvm::Type *translate(LLVM::LLVMFunctionType type) { 96 SmallVector<llvm::Type *, 8> paramTypes; 97 translateTypes(type.getParams(), paramTypes); 98 return llvm::FunctionType::get(translateType(type.getReturnType()), 99 paramTypes, type.isVarArg()); 100 } 101 102 /// Translates the given integer type. 103 llvm::Type *translate(IntegerType type) { 104 return llvm::IntegerType::get(context, type.getWidth()); 105 } 106 107 /// Translates the given pointer type. 108 llvm::Type *translate(LLVM::LLVMPointerType type) { 109 return llvm::PointerType::get(context, type.getAddressSpace()); 110 } 111 112 /// Translates the given structure type, supports both identified and literal 113 /// structs. This will _create_ a new identified structure every time, use 114 /// `convertType` if a structure with the same name must be looked up instead. 115 llvm::Type *translate(LLVM::LLVMStructType type) { 116 SmallVector<llvm::Type *, 8> subtypes; 117 if (!type.isIdentified()) { 118 translateTypes(type.getBody(), subtypes); 119 return llvm::StructType::get(context, subtypes, type.isPacked()); 120 } 121 122 llvm::StructType *structType = 123 llvm::StructType::create(context, type.getName()); 124 // Mark the type we just created as known so that recursive calls can pick 125 // it up and use directly. 126 knownTranslations.try_emplace(type, structType); 127 if (type.isOpaque()) 128 return structType; 129 130 translateTypes(type.getBody(), subtypes); 131 structType->setBody(subtypes, type.isPacked()); 132 return structType; 133 } 134 135 /// Translates the given built-in vector type compatible with LLVM. 136 llvm::Type *translate(VectorType type) { 137 assert(LLVM::isCompatibleVectorType(type) && 138 "expected compatible with LLVM vector type"); 139 if (type.isScalable()) 140 return llvm::ScalableVectorType::get(translateType(type.getElementType()), 141 type.getNumElements()); 142 return llvm::FixedVectorType::get(translateType(type.getElementType()), 143 type.getNumElements()); 144 } 145 146 /// Translates the given fixed-vector type. 147 llvm::Type *translate(LLVM::LLVMFixedVectorType type) { 148 return llvm::FixedVectorType::get(translateType(type.getElementType()), 149 type.getNumElements()); 150 } 151 152 /// Translates the given scalable-vector type. 153 llvm::Type *translate(LLVM::LLVMScalableVectorType type) { 154 return llvm::ScalableVectorType::get(translateType(type.getElementType()), 155 type.getMinNumElements()); 156 } 157 158 /// Translates the given target extension type. 159 llvm::Type *translate(LLVM::LLVMTargetExtType type) { 160 SmallVector<llvm::Type *> typeParams; 161 translateTypes(type.getTypeParams(), typeParams); 162 return llvm::TargetExtType::get(context, type.getExtTypeName(), typeParams, 163 type.getIntParams()); 164 } 165 166 /// Translates a list of types. 167 void translateTypes(ArrayRef<Type> types, 168 SmallVectorImpl<llvm::Type *> &result) { 169 result.reserve(result.size() + types.size()); 170 for (auto type : types) 171 result.push_back(translateType(type)); 172 } 173 174 /// Reference to the context in which the LLVM IR types are created. 175 llvm::LLVMContext &context; 176 177 /// Map of known translation. This serves a double purpose: caches translation 178 /// results to avoid repeated recursive calls and makes sure identified 179 /// structs with the same name (that is, equal) are resolved to an existing 180 /// type instead of creating a new type. 181 llvm::DenseMap<Type, llvm::Type *> knownTranslations; 182 }; 183 } // namespace detail 184 } // namespace LLVM 185 } // namespace mlir 186 187 LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context) 188 : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {} 189 190 LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() = default; 191 192 llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) { 193 return impl->translateType(type); 194 } 195 196 unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment( 197 Type type, const llvm::DataLayout &layout) { 198 return layout.getPrefTypeAlign(translateType(type)).value(); 199 } 200