1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h" 10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/MLIRContext.h" 13 14 #include "llvm/ADT/TypeSwitch.h" 15 #include "llvm/IR/DataLayout.h" 16 #include "llvm/IR/DerivedTypes.h" 17 #include "llvm/IR/Type.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 namespace LLVM { 23 namespace detail { 24 /// Support for translating LLVM IR types to MLIR LLVM dialect types. 25 class TypeFromLLVMIRTranslatorImpl { 26 public: 27 /// Constructs a class creating types in the given MLIR context. 28 TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} 29 30 /// Translates the given type. 31 Type translateType(llvm::Type *type) { 32 if (knownTranslations.count(type)) 33 return knownTranslations.lookup(type); 34 35 Type translated = 36 llvm::TypeSwitch<llvm::Type *, Type>(type) 37 .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType, 38 llvm::PointerType, llvm::StructType, llvm::FixedVectorType, 39 llvm::ScalableVectorType, llvm::TargetExtType>( 40 [this](auto *type) { return this->translate(type); }) 41 .Default([this](llvm::Type *type) { 42 return translatePrimitiveType(type); 43 }); 44 knownTranslations.try_emplace(type, translated); 45 return translated; 46 } 47 48 private: 49 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, 50 /// type. 51 Type translatePrimitiveType(llvm::Type *type) { 52 if (type->isVoidTy()) 53 return LLVM::LLVMVoidType::get(&context); 54 if (type->isHalfTy()) 55 return Float16Type::get(&context); 56 if (type->isBFloatTy()) 57 return BFloat16Type::get(&context); 58 if (type->isFloatTy()) 59 return Float32Type::get(&context); 60 if (type->isDoubleTy()) 61 return Float64Type::get(&context); 62 if (type->isFP128Ty()) 63 return Float128Type::get(&context); 64 if (type->isX86_FP80Ty()) 65 return Float80Type::get(&context); 66 if (type->isX86_AMXTy()) 67 return LLVM::LLVMX86AMXType::get(&context); 68 if (type->isPPC_FP128Ty()) 69 return LLVM::LLVMPPCFP128Type::get(&context); 70 if (type->isLabelTy()) 71 return LLVM::LLVMLabelType::get(&context); 72 if (type->isMetadataTy()) 73 return LLVM::LLVMMetadataType::get(&context); 74 if (type->isTokenTy()) 75 return LLVM::LLVMTokenType::get(&context); 76 llvm_unreachable("not a primitive type"); 77 } 78 79 /// Translates the given array type. 80 Type translate(llvm::ArrayType *type) { 81 return LLVM::LLVMArrayType::get(translateType(type->getElementType()), 82 type->getNumElements()); 83 } 84 85 /// Translates the given function type. 86 Type translate(llvm::FunctionType *type) { 87 SmallVector<Type, 8> paramTypes; 88 translateTypes(type->params(), paramTypes); 89 return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), 90 paramTypes, type->isVarArg()); 91 } 92 93 /// Translates the given integer type. 94 Type translate(llvm::IntegerType *type) { 95 return IntegerType::get(&context, type->getBitWidth()); 96 } 97 98 /// Translates the given pointer type. 99 Type translate(llvm::PointerType *type) { 100 return LLVM::LLVMPointerType::get(&context, type->getAddressSpace()); 101 } 102 103 /// Translates the given structure type. 104 Type translate(llvm::StructType *type) { 105 SmallVector<Type, 8> subtypes; 106 if (type->isLiteral()) { 107 translateTypes(type->subtypes(), subtypes); 108 return LLVM::LLVMStructType::getLiteral(&context, subtypes, 109 type->isPacked()); 110 } 111 112 if (type->isOpaque()) 113 return LLVM::LLVMStructType::getOpaque(type->getName(), &context); 114 115 // With opaque pointers, types in LLVM can't be recursive anymore. Note that 116 // using getIdentified is not possible, as type names in LLVM are not 117 // guaranteed to be unique. 118 translateTypes(type->subtypes(), subtypes); 119 LLVM::LLVMStructType translated = LLVM::LLVMStructType::getNewIdentified( 120 &context, type->getName(), subtypes, type->isPacked()); 121 knownTranslations.try_emplace(type, translated); 122 return translated; 123 } 124 125 /// Translates the given fixed-vector type. 126 Type translate(llvm::FixedVectorType *type) { 127 return LLVM::getFixedVectorType(translateType(type->getElementType()), 128 type->getNumElements()); 129 } 130 131 /// Translates the given scalable-vector type. 132 Type translate(llvm::ScalableVectorType *type) { 133 return LLVM::LLVMScalableVectorType::get( 134 translateType(type->getElementType()), type->getMinNumElements()); 135 } 136 137 /// Translates the given target extension type. 138 Type translate(llvm::TargetExtType *type) { 139 SmallVector<Type> typeParams; 140 translateTypes(type->type_params(), typeParams); 141 142 return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams, 143 type->int_params()); 144 } 145 146 /// Translates a list of types. 147 void translateTypes(ArrayRef<llvm::Type *> types, 148 SmallVectorImpl<Type> &result) { 149 result.reserve(result.size() + types.size()); 150 for (llvm::Type *type : types) 151 result.push_back(translateType(type)); 152 } 153 154 /// Map of known translations. Serves as a cache and as recursion stopper for 155 /// translating recursive structs. 156 llvm::DenseMap<llvm::Type *, Type> knownTranslations; 157 158 /// The context in which MLIR types are created. 159 MLIRContext &context; 160 }; 161 162 } // namespace detail 163 } // namespace LLVM 164 } // namespace mlir 165 166 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context) 167 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {} 168 169 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default; 170 171 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { 172 return impl->translateType(type); 173 } 174