xref: /llvm-project/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp (revision 2f743ac52e945e155ff3cb1f8ca5287b306b831e)
1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
13 
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace LLVM {
23 namespace detail {
24 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
25 class TypeFromLLVMIRTranslatorImpl {
26 public:
27   /// Constructs a class creating types in the given MLIR context.
28   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
29 
30   /// Translates the given type.
31   Type translateType(llvm::Type *type) {
32     if (knownTranslations.count(type))
33       return knownTranslations.lookup(type);
34 
35     Type translated =
36         llvm::TypeSwitch<llvm::Type *, Type>(type)
37             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
38                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
39                   llvm::ScalableVectorType, llvm::TargetExtType>(
40                 [this](auto *type) { return this->translate(type); })
41             .Default([this](llvm::Type *type) {
42               return translatePrimitiveType(type);
43             });
44     knownTranslations.try_emplace(type, translated);
45     return translated;
46   }
47 
48 private:
49   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
50   /// type.
51   Type translatePrimitiveType(llvm::Type *type) {
52     if (type->isVoidTy())
53       return LLVM::LLVMVoidType::get(&context);
54     if (type->isHalfTy())
55       return Float16Type::get(&context);
56     if (type->isBFloatTy())
57       return BFloat16Type::get(&context);
58     if (type->isFloatTy())
59       return Float32Type::get(&context);
60     if (type->isDoubleTy())
61       return Float64Type::get(&context);
62     if (type->isFP128Ty())
63       return Float128Type::get(&context);
64     if (type->isX86_FP80Ty())
65       return Float80Type::get(&context);
66     if (type->isX86_AMXTy())
67       return LLVM::LLVMX86AMXType::get(&context);
68     if (type->isPPC_FP128Ty())
69       return LLVM::LLVMPPCFP128Type::get(&context);
70     if (type->isLabelTy())
71       return LLVM::LLVMLabelType::get(&context);
72     if (type->isMetadataTy())
73       return LLVM::LLVMMetadataType::get(&context);
74     if (type->isTokenTy())
75       return LLVM::LLVMTokenType::get(&context);
76     llvm_unreachable("not a primitive type");
77   }
78 
79   /// Translates the given array type.
80   Type translate(llvm::ArrayType *type) {
81     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
82                                     type->getNumElements());
83   }
84 
85   /// Translates the given function type.
86   Type translate(llvm::FunctionType *type) {
87     SmallVector<Type, 8> paramTypes;
88     translateTypes(type->params(), paramTypes);
89     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
90                                        paramTypes, type->isVarArg());
91   }
92 
93   /// Translates the given integer type.
94   Type translate(llvm::IntegerType *type) {
95     return IntegerType::get(&context, type->getBitWidth());
96   }
97 
98   /// Translates the given pointer type.
99   Type translate(llvm::PointerType *type) {
100     return LLVM::LLVMPointerType::get(&context, type->getAddressSpace());
101   }
102 
103   /// Translates the given structure type.
104   Type translate(llvm::StructType *type) {
105     SmallVector<Type, 8> subtypes;
106     if (type->isLiteral()) {
107       translateTypes(type->subtypes(), subtypes);
108       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
109                                               type->isPacked());
110     }
111 
112     if (type->isOpaque())
113       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
114 
115     // With opaque pointers, types in LLVM can't be recursive anymore. Note that
116     // using getIdentified is not possible, as type names in LLVM are not
117     // guaranteed to be unique.
118     translateTypes(type->subtypes(), subtypes);
119     LLVM::LLVMStructType translated = LLVM::LLVMStructType::getNewIdentified(
120         &context, type->getName(), subtypes, type->isPacked());
121     knownTranslations.try_emplace(type, translated);
122     return translated;
123   }
124 
125   /// Translates the given fixed-vector type.
126   Type translate(llvm::FixedVectorType *type) {
127     return LLVM::getFixedVectorType(translateType(type->getElementType()),
128                                     type->getNumElements());
129   }
130 
131   /// Translates the given scalable-vector type.
132   Type translate(llvm::ScalableVectorType *type) {
133     return LLVM::LLVMScalableVectorType::get(
134         translateType(type->getElementType()), type->getMinNumElements());
135   }
136 
137   /// Translates the given target extension type.
138   Type translate(llvm::TargetExtType *type) {
139     SmallVector<Type> typeParams;
140     translateTypes(type->type_params(), typeParams);
141 
142     return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams,
143                                         type->int_params());
144   }
145 
146   /// Translates a list of types.
147   void translateTypes(ArrayRef<llvm::Type *> types,
148                       SmallVectorImpl<Type> &result) {
149     result.reserve(result.size() + types.size());
150     for (llvm::Type *type : types)
151       result.push_back(translateType(type));
152   }
153 
154   /// Map of known translations. Serves as a cache and as recursion stopper for
155   /// translating recursive structs.
156   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
157 
158   /// The context in which MLIR types are created.
159   MLIRContext &context;
160 };
161 
162 } // namespace detail
163 } // namespace LLVM
164 } // namespace mlir
165 
166 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
167     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
168 
169 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
170 
171 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
172   return impl->translateType(type);
173 }
174