xref: /llvm-project/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
15 
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
18 #include "LoopAnnotationTranslation.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
22 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
25 #include "mlir/IR/AttrTypeSubElements.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/RegionGraphTraits.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Support/LogicalResult.h"
32 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
33 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
34 
35 #include "llvm/ADT/PostOrderIterator.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/CFG.h"
41 #include "llvm/IR/Constants.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/InlineAsm.h"
45 #include "llvm/IR/IntrinsicsNVPTX.h"
46 #include "llvm/IR/LLVMContext.h"
47 #include "llvm/IR/MDBuilder.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/Verifier.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/ModuleUtils.h"
53 #include <optional>
54 
55 using namespace mlir;
56 using namespace mlir::LLVM;
57 using namespace mlir::LLVM::detail;
58 
59 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
60 
61 /// Translates the given data layout spec attribute to the LLVM IR data layout.
62 /// Only integer, float, pointer and endianness entries are currently supported.
63 static FailureOr<llvm::DataLayout>
64 translateDataLayout(DataLayoutSpecInterface attribute,
65                     const DataLayout &dataLayout,
66                     std::optional<Location> loc = std::nullopt) {
67   if (!loc)
68     loc = UnknownLoc::get(attribute.getContext());
69 
70   // Translate the endianness attribute.
71   std::string llvmDataLayout;
72   llvm::raw_string_ostream layoutStream(llvmDataLayout);
73   for (DataLayoutEntryInterface entry : attribute.getEntries()) {
74     auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
75     if (!key)
76       continue;
77     if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
78       auto value = cast<StringAttr>(entry.getValue());
79       bool isLittleEndian =
80           value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
81       layoutStream << "-" << (isLittleEndian ? "e" : "E");
82       layoutStream.flush();
83       continue;
84     }
85     if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
86       auto value = cast<IntegerAttr>(entry.getValue());
87       uint64_t space = value.getValue().getZExtValue();
88       // Skip the default address space.
89       if (space == 0)
90         continue;
91       layoutStream << "-A" << space;
92       layoutStream.flush();
93       continue;
94     }
95     if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
96       auto value = cast<IntegerAttr>(entry.getValue());
97       uint64_t alignment = value.getValue().getZExtValue();
98       // Skip the default stack alignment.
99       if (alignment == 0)
100         continue;
101       layoutStream << "-S" << alignment;
102       layoutStream.flush();
103       continue;
104     }
105     emitError(*loc) << "unsupported data layout key " << key;
106     return failure();
107   }
108 
109   // Go through the list of entries to check which types are explicitly
110   // specified in entries. Where possible, data layout queries are used instead
111   // of directly inspecting the entries.
112   for (DataLayoutEntryInterface entry : attribute.getEntries()) {
113     auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
114     if (!type)
115       continue;
116     // Data layout for the index type is irrelevant at this point.
117     if (isa<IndexType>(type))
118       continue;
119     layoutStream << "-";
120     LogicalResult result =
121         llvm::TypeSwitch<Type, LogicalResult>(type)
122             .Case<IntegerType, Float16Type, Float32Type, Float64Type,
123                   Float80Type, Float128Type>([&](Type type) -> LogicalResult {
124               if (auto intType = dyn_cast<IntegerType>(type)) {
125                 if (intType.getSignedness() != IntegerType::Signless)
126                   return emitError(*loc)
127                          << "unsupported data layout for non-signless integer "
128                          << intType;
129                 layoutStream << "i";
130               } else {
131                 layoutStream << "f";
132               }
133               unsigned size = dataLayout.getTypeSizeInBits(type);
134               unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
135               unsigned preferred =
136                   dataLayout.getTypePreferredAlignment(type) * 8u;
137               layoutStream << size << ":" << abi;
138               if (abi != preferred)
139                 layoutStream << ":" << preferred;
140               return success();
141             })
142             .Case([&](LLVMPointerType ptrType) {
143               layoutStream << "p" << ptrType.getAddressSpace() << ":";
144               unsigned size = dataLayout.getTypeSizeInBits(type);
145               unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
146               unsigned preferred =
147                   dataLayout.getTypePreferredAlignment(type) * 8u;
148               layoutStream << size << ":" << abi << ":" << preferred;
149               if (std::optional<unsigned> index = extractPointerSpecValue(
150                       entry.getValue(), PtrDLEntryPos::Index))
151                 layoutStream << ":" << *index;
152               return success();
153             })
154             .Default([loc](Type type) {
155               return emitError(*loc)
156                      << "unsupported type in data layout: " << type;
157             });
158     if (failed(result))
159       return failure();
160   }
161   layoutStream.flush();
162   StringRef layoutSpec(llvmDataLayout);
163   if (layoutSpec.startswith("-"))
164     layoutSpec = layoutSpec.drop_front();
165 
166   return llvm::DataLayout(layoutSpec);
167 }
168 
169 /// Builds a constant of a sequential LLVM type `type`, potentially containing
170 /// other sequential types recursively, from the individual constant values
171 /// provided in `constants`. `shape` contains the number of elements in nested
172 /// sequential types. Reports errors at `loc` and returns nullptr on error.
173 static llvm::Constant *
174 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
175                         ArrayRef<int64_t> shape, llvm::Type *type,
176                         Location loc) {
177   if (shape.empty()) {
178     llvm::Constant *result = constants.front();
179     constants = constants.drop_front();
180     return result;
181   }
182 
183   llvm::Type *elementType;
184   if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
185     elementType = arrayTy->getElementType();
186   } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
187     elementType = vectorTy->getElementType();
188   } else {
189     emitError(loc) << "expected sequential LLVM types wrapping a scalar";
190     return nullptr;
191   }
192 
193   SmallVector<llvm::Constant *, 8> nested;
194   nested.reserve(shape.front());
195   for (int64_t i = 0; i < shape.front(); ++i) {
196     nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
197                                              elementType, loc));
198     if (!nested.back())
199       return nullptr;
200   }
201 
202   if (shape.size() == 1 && type->isVectorTy())
203     return llvm::ConstantVector::get(nested);
204   return llvm::ConstantArray::get(
205       llvm::ArrayType::get(elementType, shape.front()), nested);
206 }
207 
208 /// Returns the first non-sequential type nested in sequential types.
209 static llvm::Type *getInnermostElementType(llvm::Type *type) {
210   do {
211     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
212       type = arrayTy->getElementType();
213     } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
214       type = vectorTy->getElementType();
215     } else {
216       return type;
217     }
218   } while (true);
219 }
220 
221 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
222 /// storage if possible. This supports elements attributes of tensor or vector
223 /// type and avoids constructing separate objects for individual values of the
224 /// innermost dimension. Constants for other dimensions are still constructed
225 /// recursively. Returns null if constructing from raw data is not supported for
226 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
227 /// other errors at `loc`.
228 static llvm::Constant *
229 convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
230                          llvm::Type *llvmType,
231                          const ModuleTranslation &moduleTranslation) {
232   if (!denseElementsAttr)
233     return nullptr;
234 
235   llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
236   if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
237     return nullptr;
238 
239   ShapedType type = denseElementsAttr.getType();
240   if (type.getNumElements() == 0)
241     return nullptr;
242 
243   // Check that the raw data size matches what is expected for the scalar size.
244   // TODO: in theory, we could repack the data here to keep constructing from
245   // raw data.
246   // TODO: we may also need to consider endianness when cross-compiling to an
247   // architecture where it is different.
248   unsigned elementByteSize = denseElementsAttr.getRawData().size() /
249                              denseElementsAttr.getNumElements();
250   if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
251     return nullptr;
252 
253   // Compute the shape of all dimensions but the innermost. Note that the
254   // innermost dimension may be that of the vector element type.
255   bool hasVectorElementType = isa<VectorType>(type.getElementType());
256   unsigned numAggregates =
257       denseElementsAttr.getNumElements() /
258       (hasVectorElementType ? 1
259                             : denseElementsAttr.getType().getShape().back());
260   ArrayRef<int64_t> outerShape = type.getShape();
261   if (!hasVectorElementType)
262     outerShape = outerShape.drop_back();
263 
264   // Handle the case of vector splat, LLVM has special support for it.
265   if (denseElementsAttr.isSplat() &&
266       (isa<VectorType>(type) || hasVectorElementType)) {
267     llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
268         innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
269         moduleTranslation);
270     llvm::Constant *splatVector =
271         llvm::ConstantDataVector::getSplat(0, splatValue);
272     SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
273     ArrayRef<llvm::Constant *> constantsRef = constants;
274     return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
275   }
276   if (denseElementsAttr.isSplat())
277     return nullptr;
278 
279   // In case of non-splat, create a constructor for the innermost constant from
280   // a piece of raw data.
281   std::function<llvm::Constant *(StringRef)> buildCstData;
282   if (isa<TensorType>(type)) {
283     auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
284     if (vectorElementType && vectorElementType.getRank() == 1) {
285       buildCstData = [&](StringRef data) {
286         return llvm::ConstantDataVector::getRaw(
287             data, vectorElementType.getShape().back(), innermostLLVMType);
288       };
289     } else if (!vectorElementType) {
290       buildCstData = [&](StringRef data) {
291         return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
292                                                innermostLLVMType);
293       };
294     }
295   } else if (isa<VectorType>(type)) {
296     buildCstData = [&](StringRef data) {
297       return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
298                                               innermostLLVMType);
299     };
300   }
301   if (!buildCstData)
302     return nullptr;
303 
304   // Create innermost constants and defer to the default constant creation
305   // mechanism for other dimensions.
306   SmallVector<llvm::Constant *> constants;
307   unsigned aggregateSize = denseElementsAttr.getType().getShape().back() *
308                            (innermostLLVMType->getScalarSizeInBits() / 8);
309   constants.reserve(numAggregates);
310   for (unsigned i = 0; i < numAggregates; ++i) {
311     StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
312                    aggregateSize);
313     constants.push_back(buildCstData(data));
314   }
315 
316   ArrayRef<llvm::Constant *> constantsRef = constants;
317   return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
318 }
319 
320 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
321 /// This currently supports integer, floating point, splat and dense element
322 /// attributes and combinations thereof. Also, an array attribute with two
323 /// elements is supported to represent a complex constant.  In case of error,
324 /// report it to `loc` and return nullptr.
325 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
326     llvm::Type *llvmType, Attribute attr, Location loc,
327     const ModuleTranslation &moduleTranslation) {
328   if (!attr)
329     return llvm::UndefValue::get(llvmType);
330   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
331     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
332     if (!arrayAttr || arrayAttr.size() != 2) {
333       emitError(loc, "expected struct type to be a complex number");
334       return nullptr;
335     }
336     llvm::Type *elementType = structType->getElementType(0);
337     llvm::Constant *real =
338         getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
339     if (!real)
340       return nullptr;
341     llvm::Constant *imag =
342         getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
343     if (!imag)
344       return nullptr;
345     return llvm::ConstantStruct::get(structType, {real, imag});
346   }
347   if (auto *targetExtType = dyn_cast<::llvm::TargetExtType>(llvmType)) {
348     // TODO: Replace with 'zeroinitializer' once there is a dedicated
349     // zeroinitializer operation in the LLVM dialect.
350     auto intAttr = dyn_cast<IntegerAttr>(attr);
351     if (!intAttr || intAttr.getInt() != 0)
352       emitError(loc,
353                 "Only zero-initialization allowed for target extension type");
354 
355     return llvm::ConstantTargetNone::get(targetExtType);
356   }
357   // For integer types, we allow a mismatch in sizes as the index type in
358   // MLIR might have a different size than the index type in the LLVM module.
359   if (auto intAttr = dyn_cast<IntegerAttr>(attr))
360     return llvm::ConstantInt::get(
361         llvmType,
362         intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
363   if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
364     const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
365     // Special case for 8-bit floats, which are represented by integers due to
366     // the lack of native fp8 types in LLVM at the moment. Additionally, handle
367     // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
368     // to i16.
369     unsigned floatWidth = APFloat::getSizeInBits(sem);
370     if (llvmType->isIntegerTy(floatWidth))
371       return llvm::ConstantInt::get(llvmType,
372                                     floatAttr.getValue().bitcastToAPInt());
373     if (llvmType !=
374         llvm::Type::getFloatingPointTy(llvmType->getContext(),
375                                        floatAttr.getValue().getSemantics())) {
376       emitError(loc, "FloatAttr does not match expected type of the constant");
377       return nullptr;
378     }
379     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
380   }
381   if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
382     return llvm::ConstantExpr::getBitCast(
383         moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
384   if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
385     llvm::Type *elementType;
386     uint64_t numElements;
387     bool isScalable = false;
388     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
389       elementType = arrayTy->getElementType();
390       numElements = arrayTy->getNumElements();
391     } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
392       elementType = fVectorTy->getElementType();
393       numElements = fVectorTy->getNumElements();
394     } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
395       elementType = sVectorTy->getElementType();
396       numElements = sVectorTy->getMinNumElements();
397       isScalable = true;
398     } else {
399       llvm_unreachable("unrecognized constant vector type");
400     }
401     // Splat value is a scalar. Extract it only if the element type is not
402     // another sequence type. The recursion terminates because each step removes
403     // one outer sequential type.
404     bool elementTypeSequential =
405         isa<llvm::ArrayType, llvm::VectorType>(elementType);
406     llvm::Constant *child = getLLVMConstant(
407         elementType,
408         elementTypeSequential ? splatAttr
409                               : splatAttr.getSplatValue<Attribute>(),
410         loc, moduleTranslation);
411     if (!child)
412       return nullptr;
413     if (llvmType->isVectorTy())
414       return llvm::ConstantVector::getSplat(
415           llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
416     if (llvmType->isArrayTy()) {
417       auto *arrayType = llvm::ArrayType::get(elementType, numElements);
418       SmallVector<llvm::Constant *, 8> constants(numElements, child);
419       return llvm::ConstantArray::get(arrayType, constants);
420     }
421   }
422 
423   // Try using raw elements data if possible.
424   if (llvm::Constant *result =
425           convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
426                                    llvmType, moduleTranslation)) {
427     return result;
428   }
429 
430   // Fall back to element-by-element construction otherwise.
431   if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
432     assert(elementsAttr.getShapedType().hasStaticShape());
433     assert(!elementsAttr.getShapedType().getShape().empty() &&
434            "unexpected empty elements attribute shape");
435 
436     SmallVector<llvm::Constant *, 8> constants;
437     constants.reserve(elementsAttr.getNumElements());
438     llvm::Type *innermostType = getInnermostElementType(llvmType);
439     for (auto n : elementsAttr.getValues<Attribute>()) {
440       constants.push_back(
441           getLLVMConstant(innermostType, n, loc, moduleTranslation));
442       if (!constants.back())
443         return nullptr;
444     }
445     ArrayRef<llvm::Constant *> constantsRef = constants;
446     llvm::Constant *result = buildSequentialConstant(
447         constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
448     assert(constantsRef.empty() && "did not consume all elemental constants");
449     return result;
450   }
451 
452   if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
453     return llvm::ConstantDataArray::get(
454         moduleTranslation.getLLVMContext(),
455         ArrayRef<char>{stringAttr.getValue().data(),
456                        stringAttr.getValue().size()});
457   }
458   emitError(loc, "unsupported constant value");
459   return nullptr;
460 }
461 
/// Creates the translation state for the MLIR `module`, taking ownership of
/// `llvmModule`. The debug-info, loop-annotation and type translators are all
/// bound to that owned LLVM module. `iface` is built from the MLIR context and
/// is later queried for per-operation translation interfaces.
ModuleTranslation::ModuleTranslation(Operation *module,
                                     std::unique_ptr<llvm::Module> llvmModule)
    : mlirModule(module), llvmModule(std::move(llvmModule)),
      // `this->llvmModule` names the member: the constructor parameter of the
      // same name has been moved into it above. NOTE(review): this relies on
      // the member being declared before the translators below — confirm
      // against the class declaration.
      debugTranslation(
          std::make_unique<DebugTranslation>(module, *this->llvmModule)),
      loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
          *this, *this->llvmModule)),
      typeTranslator(this->llvmModule->getContext()),
      iface(module->getContext()) {
  // Only operations honoring LLVM module semantics can be translated.
  assert(satisfiesLLVMModule(mlirModule) &&
         "mlirModule should honor LLVM's module semantics.");
}
474 
475 ModuleTranslation::~ModuleTranslation() {
476   if (ompBuilder)
477     ompBuilder->finalize();
478 }
479 
480 void ModuleTranslation::forgetMapping(Region &region) {
481   SmallVector<Region *> toProcess;
482   toProcess.push_back(&region);
483   while (!toProcess.empty()) {
484     Region *current = toProcess.pop_back_val();
485     for (Block &block : *current) {
486       blockMapping.erase(&block);
487       for (Value arg : block.getArguments())
488         valueMapping.erase(arg);
489       for (Operation &op : block) {
490         for (Value value : op.getResults())
491           valueMapping.erase(value);
492         if (op.hasSuccessors())
493           branchMapping.erase(&op);
494         if (isa<LLVM::GlobalOp>(op))
495           globalsMapping.erase(&op);
496         llvm::append_range(
497             toProcess,
498             llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
499       }
500     }
501   }
502 }
503 
/// Returns the SSA value that the terminator of `pred` forwards to the
/// `index`-th argument of `current`. `numArguments` is not used by the lookup.
/// The terminator must be a branch, conditional branch, switch or invoke.
static Value getPHISourceValue(Block *current, Block *pred,
                               unsigned numArguments, unsigned index) {
  Operation &term = *pred->getTerminator();
  // An unconditional branch forwards its operands one-to-one.
  if (isa<LLVM::BrOp>(term))
    return term.getOperand(index);

#ifndef NDEBUG
  // Several edges to the same successor are only allowed when none of them
  // carries operands; otherwise the source value would be ambiguous.
  llvm::SmallPtrSet<Block *, 4> visited;
  for (unsigned succIdx = 0, numSucc = term.getNumSuccessors();
       succIdx < numSucc; ++succIdx) {
    Block *succ = term.getSuccessor(succIdx);
    auto branchOp = cast<BranchOpInterface>(term);
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succIdx);
    assert((!visited.contains(succ) || succOperands.empty()) &&
           "successors with arguments in LLVM branches must be different blocks");
    visited.insert(succ);
  }
#endif

  // Conditional branch: pick the operand list of the edge leading to
  // `current`.
  if (auto condBr = dyn_cast<LLVM::CondBrOp>(term)) {
    if (condBr.getSuccessor(0) == current)
      return condBr.getTrueDestOperands()[index];
    return condBr.getFalseDestOperands()[index];
  }

  // Switch: the default edge first, then the case edges in order.
  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(term)) {
    if (switchOp.getDefaultDestination() == current)
      return switchOp.getDefaultOperands()[index];
    for (auto [caseIdx, caseDest] :
         llvm::enumerate(switchOp.getCaseDestinations()))
      if (caseDest == current)
        return switchOp.getCaseOperands(caseIdx)[index];
  }

  // Invoke: either the normal or the unwind destination.
  if (auto invoke = dyn_cast<LLVM::InvokeOp>(term)) {
    if (invoke.getNormalDest() == current)
      return invoke.getNormalDestOperands()[index];
    return invoke.getUnwindDestOperands()[index];
  }

  llvm_unreachable(
      "only branch, switch or invoke operations can be terminators "
      "of a block that has successors");
}
555 
556 /// Connect the PHI nodes to the results of preceding blocks.
557 void mlir::LLVM::detail::connectPHINodes(Region &region,
558                                          const ModuleTranslation &state) {
559   // Skip the first block, it cannot be branched to and its arguments correspond
560   // to the arguments of the LLVM function.
561   for (Block &bb : llvm::drop_begin(region)) {
562     llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
563     auto phis = llvmBB->phis();
564     auto numArguments = bb.getNumArguments();
565     assert(numArguments == std::distance(phis.begin(), phis.end()));
566     for (auto [index, phiNode] : llvm::enumerate(phis)) {
567       for (auto *pred : bb.getPredecessors()) {
568         // Find the LLVM IR block that contains the converted terminator
569         // instruction and use it in the PHI node. Note that this block is not
570         // necessarily the same as state.lookupBlock(pred), some operations
571         // (in particular, OpenMP operations using OpenMPIRBuilder) may have
572         // split the blocks.
573         llvm::Instruction *terminator =
574             state.lookupBranch(pred->getTerminator());
575         assert(terminator && "missing the mapping for a terminator");
576         phiNode.addIncoming(state.lookupValue(getPHISourceValue(
577                                 &bb, pred, numArguments, index)),
578                             terminator->getParent());
579       }
580     }
581   }
582 }
583 
584 /// Sort function blocks topologically.
585 SetVector<Block *>
586 mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
587   // For each block that has not been visited yet (i.e. that has no
588   // predecessors), add it to the list as well as its successors.
589   SetVector<Block *> blocks;
590   for (Block &b : region) {
591     if (blocks.count(&b) == 0) {
592       llvm::ReversePostOrderTraversal<Block *> traversal(&b);
593       blocks.insert(traversal.begin(), traversal.end());
594     }
595   }
596   assert(blocks.size() == region.getBlocks().size() &&
597          "some blocks are not sorted");
598 
599   return blocks;
600 }
601 
602 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
603     llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
604     ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
605   llvm::Module *module = builder.GetInsertBlock()->getModule();
606   llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
607   return builder.CreateCall(fn, args);
608 }
609 
610 /// Given a single MLIR operation, create the corresponding LLVM IR operation
611 /// using the `builder`.
612 LogicalResult
613 ModuleTranslation::convertOperation(Operation &op,
614                                     llvm::IRBuilderBase &builder) {
615   const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
616   if (!opIface)
617     return op.emitError("cannot be converted to LLVM IR: missing "
618                         "`LLVMTranslationDialectInterface` registration for "
619                         "dialect for op: ")
620            << op.getName();
621 
622   if (failed(opIface->convertOperation(&op, builder, *this)))
623     return op.emitError("LLVM Translation failed for operation: ")
624            << op.getName();
625 
626   return convertDialectAttributes(&op);
627 }
628 
629 /// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
630 /// to define values corresponding to the MLIR block arguments.  These nodes
631 /// are not connected to the source basic blocks, which may not exist yet.  Uses
632 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
633 /// been created for `bb` and included in the block mapping.  Inserts new
634 /// instructions at the end of the block and leaves `builder` in a state
635 /// suitable for further insertion into the end of the block.
636 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
637                                               llvm::IRBuilderBase &builder) {
638   builder.SetInsertPoint(lookupBlock(&bb));
639   auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
640 
641   // Before traversing operations, make block arguments available through
642   // value remapping and PHI nodes, but do not add incoming edges for the PHI
643   // nodes just yet: those values may be defined by this or following blocks.
644   // This step is omitted if "ignoreArguments" is set.  The arguments of the
645   // first block have been already made available through the remapping of
646   // LLVM function arguments.
647   if (!ignoreArguments) {
648     auto predecessors = bb.getPredecessors();
649     unsigned numPredecessors =
650         std::distance(predecessors.begin(), predecessors.end());
651     for (auto arg : bb.getArguments()) {
652       auto wrappedType = arg.getType();
653       if (!isCompatibleType(wrappedType))
654         return emitError(bb.front().getLoc(),
655                          "block argument does not have an LLVM type");
656       llvm::Type *type = convertType(wrappedType);
657       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
658       mapValue(arg, phi);
659     }
660   }
661 
662   // Traverse operations.
663   for (auto &op : bb) {
664     // Set the current debug location within the builder.
665     builder.SetCurrentDebugLocation(
666         debugTranslation->translateLoc(op.getLoc(), subprogram));
667 
668     if (failed(convertOperation(op, builder)))
669       return failure();
670 
671     // Set the branch weight metadata on the translated instruction.
672     if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
673       setBranchWeightsMetadata(iface);
674   }
675 
676   return success();
677 }
678 
679 /// A helper method to get the single Block in an operation honoring LLVM's
680 /// module requirements.
681 static Block &getModuleBody(Operation *module) {
682   return module->getRegion(0).front();
683 }
684 
685 /// A helper method to decide if a constant must not be set as a global variable
686 /// initializer. For an external linkage variable, the variable with an
687 /// initializer is considered externally visible and defined in this module, the
688 /// variable without an initializer is externally available and is defined
689 /// elsewhere.
690 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
691                                         llvm::Constant *cst) {
692   return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
693          linkage == llvm::GlobalVariable::ExternalWeakLinkage;
694 }
695 
696 /// Sets the runtime preemption specifier of `gv` to dso_local if
697 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
698 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
699                                           llvm::GlobalValue *gv) {
700   if (dsoLocalRequested)
701     gv->setDSOLocal(true);
702 }
703 
704 /// Create named global variables that correspond to llvm.mlir.global
705 /// definitions. Convert llvm.global_ctors and global_dtors ops.
706 LogicalResult ModuleTranslation::convertGlobals() {
707   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
708     llvm::Type *type = convertType(op.getType());
709     llvm::Constant *cst = nullptr;
710     if (op.getValueOrNull()) {
711       // String attributes are treated separately because they cannot appear as
712       // in-function constants and are thus not supported by getLLVMConstant.
713       if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
714         cst = llvm::ConstantDataArray::getString(
715             llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
716         type = cst->getType();
717       } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
718                                          *this))) {
719         return failure();
720       }
721     }
722 
723     auto linkage = convertLinkageToLLVM(op.getLinkage());
724     auto addrSpace = op.getAddrSpace();
725 
726     // LLVM IR requires constant with linkage other than external or weak
727     // external to have initializers. If MLIR does not provide an initializer,
728     // default to undef.
729     bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
730     if (!dropInitializer && !cst)
731       cst = llvm::UndefValue::get(type);
732     else if (dropInitializer && cst)
733       cst = nullptr;
734 
735     auto *var = new llvm::GlobalVariable(
736         *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
737         /*InsertBefore=*/nullptr,
738         op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
739                              : llvm::GlobalValue::NotThreadLocal,
740         addrSpace);
741 
742     if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
743       auto selectorOp = cast<ComdatSelectorOp>(
744           SymbolTable::lookupNearestSymbolFrom(op, *comdat));
745       var->setComdat(comdatMapping.lookup(selectorOp));
746     }
747 
748     if (op.getUnnamedAddr().has_value())
749       var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
750 
751     if (op.getSection().has_value())
752       var->setSection(*op.getSection());
753 
754     addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
755 
756     std::optional<uint64_t> alignment = op.getAlignment();
757     if (alignment.has_value())
758       var->setAlignment(llvm::MaybeAlign(alignment.value()));
759 
760     var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
761 
762     globalsMapping.try_emplace(op, var);
763   }
764 
765   // Convert global variable bodies. This is done after all global variables
766   // have been created in LLVM IR because a global body may refer to another
767   // global or itself. So all global variables need to be mapped first.
768   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
769     if (Block *initializer = op.getInitializerBlock()) {
770       llvm::IRBuilder<> builder(llvmModule->getContext());
771       for (auto &op : initializer->without_terminator()) {
772         if (failed(convertOperation(op, builder)) ||
773             !isa<llvm::Constant>(lookupValue(op.getResult(0))))
774           return emitError(op.getLoc(), "unemittable constant value");
775       }
776       ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
777       llvm::Constant *cst =
778           cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
779       auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
780       if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
781         global->setInitializer(cst);
782     }
783   }
784 
785   // Convert llvm.mlir.global_ctors and dtors.
786   for (Operation &op : getModuleBody(mlirModule)) {
787     auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
788     auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
789     if (!ctorOp && !dtorOp)
790       continue;
791     auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
792                         : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
793     auto appendGlobalFn =
794         ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
795     for (auto symbolAndPriority : range) {
796       llvm::Function *f = lookupFunction(
797           cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
798       appendGlobalFn(*llvmModule, f,
799                      cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
800                      /*Data=*/nullptr);
801     }
802   }
803 
804   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
805     if (failed(convertDialectAttributes(op)))
806       return failure();
807 
808   return success();
809 }
810 
811 /// Attempts to add an attribute identified by `key`, optionally with the given
812 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
813 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
814 /// otherwise keep it as a string attribute. Performs additional checks for
815 /// attributes known to have or not have a value in order to avoid assertions
816 /// inside LLVM upon construction.
817 static LogicalResult checkedAddLLVMFnAttribute(Location loc,
818                                                llvm::Function *llvmFunc,
819                                                StringRef key,
820                                                StringRef value = StringRef()) {
821   auto kind = llvm::Attribute::getAttrKindFromName(key);
822   if (kind == llvm::Attribute::None) {
823     llvmFunc->addFnAttr(key, value);
824     return success();
825   }
826 
827   if (llvm::Attribute::isIntAttrKind(kind)) {
828     if (value.empty())
829       return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
830 
831     int64_t result;
832     if (!value.getAsInteger(/*Radix=*/0, result))
833       llvmFunc->addFnAttr(
834           llvm::Attribute::get(llvmFunc->getContext(), kind, result));
835     else
836       llvmFunc->addFnAttr(key, value);
837     return success();
838   }
839 
840   if (!value.empty())
841     return emitError(loc) << "LLVM attribute '" << key
842                           << "' does not expect a value, found '" << value
843                           << "'";
844 
845   llvmFunc->addFnAttr(kind);
846   return success();
847 }
848 
849 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
850 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
851 /// to be an array attribute containing either string attributes, treated as
852 /// value-less LLVM attributes, or array attributes containing two string
853 /// attributes, with the first string being the name of the corresponding LLVM
854 /// attribute and the second string beings its value. Note that even integer
855 /// attributes are expected to have their values expressed as strings.
856 static LogicalResult
857 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
858                              llvm::Function *llvmFunc) {
859   if (!attributes)
860     return success();
861 
862   for (Attribute attr : *attributes) {
863     if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
864       if (failed(
865               checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
866         return failure();
867       continue;
868     }
869 
870     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
871     if (!arrayAttr || arrayAttr.size() != 2)
872       return emitError(loc)
873              << "expected 'passthrough' to contain string or array attributes";
874 
875     auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
876     auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
877     if (!keyAttr || !valueAttr)
878       return emitError(loc)
879              << "expected arrays within 'passthrough' to contain two strings";
880 
881     if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
882                                          valueAttr.getValue())))
883       return failure();
884   }
885   return success();
886 }
887 
888 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
889   // Clear the block, branch value mappings, they are only relevant within one
890   // function.
891   blockMapping.clear();
892   valueMapping.clear();
893   branchMapping.clear();
894   llvm::Function *llvmFunc = lookupFunction(func.getName());
895 
896   // Translate the debug information for this function.
897   debugTranslation->translate(func, *llvmFunc);
898 
899   // Add function arguments to the value remapping table.
900   for (auto [mlirArg, llvmArg] :
901        llvm::zip(func.getArguments(), llvmFunc->args()))
902     mapValue(mlirArg, &llvmArg);
903 
904   // Check the personality and set it.
905   if (func.getPersonality()) {
906     llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
907     if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
908                                                 func.getLoc(), *this))
909       llvmFunc->setPersonalityFn(pfunc);
910   }
911 
912   if (std::optional<StringRef> section = func.getSection())
913     llvmFunc->setSection(*section);
914 
915   if (func.getArmStreaming())
916     llvmFunc->addFnAttr("aarch64_pstate_sm_enabled");
917   else if (func.getArmLocallyStreaming())
918     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
919 
920   // First, create all blocks so we can jump to them.
921   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
922   for (auto &bb : func) {
923     auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
924     llvmBB->insertInto(llvmFunc);
925     mapBlock(&bb, llvmBB);
926   }
927 
928   // Then, convert blocks one by one in topological order to ensure defs are
929   // converted before uses.
930   auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
931   for (Block *bb : blocks) {
932     llvm::IRBuilder<> builder(llvmContext);
933     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
934       return failure();
935   }
936 
937   // After all blocks have been traversed and values mapped, connect the PHI
938   // nodes to the results of preceding blocks.
939   detail::connectPHINodes(func.getBody(), *this);
940 
941   // Finally, convert dialect attributes attached to the function.
942   return convertDialectAttributes(func);
943 }
944 
945 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
946   for (NamedAttribute attribute : op->getDialectAttrs())
947     if (failed(iface.amendOperation(op, attribute, *this)))
948       return failure();
949   return success();
950 }
951 
952 /// Converts the function attributes from LLVMFuncOp and attaches them to the
953 /// llvm::Function.
954 static void convertFunctionAttributes(LLVMFuncOp func,
955                                       llvm::Function *llvmFunc) {
956   if (!func.getMemory())
957     return;
958 
959   MemoryEffectsAttr memEffects = func.getMemoryAttr();
960 
961   // Add memory effects incrementally.
962   llvm::MemoryEffects newMemEffects =
963       llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
964                           convertModRefInfoToLLVM(memEffects.getArgMem()));
965   newMemEffects |= llvm::MemoryEffects(
966       llvm::MemoryEffects::Location::InaccessibleMem,
967       convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
968   newMemEffects |=
969       llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
970                           convertModRefInfoToLLVM(memEffects.getOther()));
971   llvmFunc->setMemoryEffects(newMemEffects);
972 }
973 
974 llvm::AttrBuilder
975 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
976   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
977 
978   for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
979     Attribute attr = paramAttrs.get(mlirName);
980     // Skip attributes that are not present.
981     if (!attr)
982       continue;
983 
984     // NOTE: C++17 does not support capturing structured bindings.
985     llvm::Attribute::AttrKind llvmKindCap = llvmKind;
986 
987     llvm::TypeSwitch<Attribute>(attr)
988         .Case<TypeAttr>([&](auto typeAttr) {
989           attrBuilder.addTypeAttr(llvmKindCap,
990                                   convertType(typeAttr.getValue()));
991         })
992         .Case<IntegerAttr>([&](auto intAttr) {
993           attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
994         })
995         .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
996   }
997 
998   return attrBuilder;
999 }
1000 
1001 LogicalResult ModuleTranslation::convertFunctionSignatures() {
1002   // Declare all functions first because there may be function calls that form a
1003   // call graph with cycles, or global initializers that reference functions.
1004   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1005     llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1006         function.getName(),
1007         cast<llvm::FunctionType>(convertType(function.getFunctionType())));
1008     llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1009     llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
1010     llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv()));
1011     mapFunction(function.getName(), llvmFunc);
1012     addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
1013 
1014     // Convert function attributes.
1015     convertFunctionAttributes(function, llvmFunc);
1016 
1017     // Convert function_entry_count attribute to metadata.
1018     if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1019       llvmFunc->setEntryCount(entryCount.value());
1020 
1021     // Convert result attributes.
1022     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1023       DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1024       llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
1025     }
1026 
1027     // Convert argument attributes.
1028     for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
1029       if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1030         llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
1031         llvmArg.addAttrs(attrBuilder);
1032       }
1033     }
1034 
1035     // Forward the pass-through attributes to LLVM.
1036     if (failed(forwardPassthroughAttributes(
1037             function.getLoc(), function.getPassthrough(), llvmFunc)))
1038       return failure();
1039 
1040     // Convert visibility attribute.
1041     llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
1042 
1043     // Convert the comdat attribute.
1044     if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1045       auto selectorOp = cast<ComdatSelectorOp>(
1046           SymbolTable::lookupNearestSymbolFrom(function, *comdat));
1047       llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1048     }
1049 
1050     if (auto gc = function.getGarbageCollector())
1051       llvmFunc->setGC(gc->str());
1052 
1053     if (auto unnamedAddr = function.getUnnamedAddr())
1054       llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1055 
1056     if (auto alignment = function.getAlignment())
1057       llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1058   }
1059 
1060   return success();
1061 }
1062 
1063 LogicalResult ModuleTranslation::convertFunctions() {
1064   // Convert functions.
1065   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1066     // Do not convert external functions, but do process dialect attributes
1067     // attached to them.
1068     if (function.isExternal()) {
1069       if (failed(convertDialectAttributes(function)))
1070         return failure();
1071       continue;
1072     }
1073 
1074     if (failed(convertOneFunction(function)))
1075       return failure();
1076   }
1077 
1078   return success();
1079 }
1080 
1081 LogicalResult ModuleTranslation::convertComdats() {
1082   for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
1083     for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1084       llvm::Module *module = getLLVMModule();
1085       if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1086         return emitError(selectorOp.getLoc())
1087                << "comdat selection symbols must be unique even in different "
1088                   "comdat regions";
1089       llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1090       comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1091       comdatMapping.try_emplace(selectorOp, comdat);
1092     }
1093   }
1094   return success();
1095 }
1096 
1097 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1098                                                 llvm::Instruction *inst) {
1099   if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1100     inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1101 }
1102 
/// Returns the metadata node for the alias scope `aliasScopeAttr`, creating
/// it on first use. If the scope's domain has not been translated yet, its
/// node is created as well. Nodes are cached in `aliasScopeMetadataMapping`
/// and `aliasDomainMetadataMapping`, so repeated calls with the same attribute
/// return the same node.
///
/// Node layouts:
/// - Domain: !{<self>, [!"description"]}.
/// - Scope:  !{<self>, <domain>, [!"description"]}.
llvm::MDNode *
ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
  // Reserve the cache slot up front. A pre-existing entry means the scope was
  // already translated.
  auto [scopeIt, scopeInserted] =
      aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr);
  if (!scopeInserted)
    return scopeIt->second;
  llvm::LLVMContext &ctx = llvmModule->getContext();
  // Convert the domain metadata node if necessary.
  // Only `aliasDomainMetadataMapping` is modified here, so `scopeIt` stays
  // valid until it is filled in below.
  auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
      aliasScopeAttr.getDomain(), nullptr);
  if (insertedDomain) {
    llvm::SmallVector<llvm::Metadata *, 2> operands;
    // Placeholder for self-reference.
    operands.push_back({});
    if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
      operands.push_back(llvm::MDString::get(ctx, description));
    domainIt->second = llvm::MDNode::get(ctx, operands);
    // Self-reference for uniqueness.
    domainIt->second->replaceOperandWith(0, domainIt->second);
  }
  // Convert the scope metadata node.
  assert(domainIt->second && "Scope's domain should already be valid");
  llvm::SmallVector<llvm::Metadata *, 3> operands;
  // Placeholder for self-reference.
  operands.push_back({});
  operands.push_back(domainIt->second);
  if (StringAttr description = aliasScopeAttr.getDescription())
    operands.push_back(llvm::MDString::get(ctx, description));
  scopeIt->second = llvm::MDNode::get(ctx, operands);
  // Self-reference for uniqueness.
  scopeIt->second->replaceOperandWith(0, scopeIt->second);
  return scopeIt->second;
}
1136 
1137 llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes(
1138     ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
1139   SmallVector<llvm::Metadata *> nodes;
1140   nodes.reserve(aliasScopeAttrs.size());
1141   for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1142     nodes.push_back(getOrCreateAliasScope(aliasScopeAttr));
1143   return llvm::MDNode::get(getLLVMContext(), nodes);
1144 }
1145 
1146 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1147                                               llvm::Instruction *inst) {
1148   auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1149     if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1150       return;
1151     llvm::MDNode *node = getOrCreateAliasScopes(
1152         llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1153     inst->setMetadata(kind, node);
1154   };
1155 
1156   populateScopeMetadata(op.getAliasScopesOrNull(),
1157                         llvm::LLVMContext::MD_alias_scope);
1158   populateScopeMetadata(op.getNoAliasScopesOrNull(),
1159                         llvm::LLVMContext::MD_noalias);
1160 }
1161 
1162 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1163   return tbaaMetadataMapping.lookup(tbaaAttr);
1164 }
1165 
1166 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1167                                         llvm::Instruction *inst) {
1168   ArrayAttr tagRefs = op.getTBAATagsOrNull();
1169   if (!tagRefs || tagRefs.empty())
1170     return;
1171 
1172   // LLVM IR currently does not support attaching more than one TBAA access tag
1173   // to a memory accessing instruction. It may be useful to support this in
1174   // future, but for the time being just ignore the metadata if MLIR operation
1175   // has multiple access tags.
1176   if (tagRefs.size() > 1) {
1177     op.emitWarning() << "TBAA access tags were not translated, because LLVM "
1178                         "IR only supports a single tag per instruction";
1179     return;
1180   }
1181 
1182   llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1183   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1184 }
1185 
1186 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
1187   DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
1188   if (!weightsAttr)
1189     return;
1190 
1191   llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
1192   assert(inst && "expected the operation to have a mapping to an instruction");
1193   SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
1194   inst->setMetadata(
1195       llvm::LLVMContext::MD_prof,
1196       llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
1197 }
1198 
/// Creates an LLVM metadata node for every TBAA root, type descriptor and
/// access tag attribute reachable from the TBAA tags of operations in the
/// module, and records them in `tbaaMetadataMapping`. Currently always
/// succeeds.
LogicalResult ModuleTranslation::createTBAAMetadata() {
  llvm::LLVMContext &ctx = llvmModule->getContext();
  // Offsets (and the "constant" flag) are emitted as i64 constants.
  llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);

  // Walk the entire module and create all metadata nodes for the TBAA
  // attributes. The code below relies on two invariants of the
  // `AttrTypeWalker`:
  // 1. Attributes are visited in post-order: Since the attributes create a DAG,
  //    this ensures that any lookups into `tbaaMetadataMapping` for child
  //    attributes succeed.
  // 2. Attributes are only ever visited once: This way we don't leak any
  //    LLVM metadata instances.
  AttrTypeWalker walker;
  // Root node layout: !{!"id"}.
  walker.addWalk([&](TBAARootAttr root) {
    tbaaMetadataMapping.insert(
        {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))});
  });

  // Type descriptor layout: !{!"id", member0-type, member0-offset, ...}.
  walker.addWalk([&](TBAATypeDescriptorAttr descriptor) {
    SmallVector<llvm::Metadata *> operands;
    operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
    for (TBAAMemberAttr member : descriptor.getMembers()) {
      operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
      operands.push_back(llvm::ConstantAsMetadata::get(
          llvm::ConstantInt::get(offsetTy, member.getOffset())));
    }

    tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
  });

  // Access tag layout: !{base-type, access-type, offset[, i64 1 if constant]}.
  walker.addWalk([&](TBAATagAttr tag) {
    SmallVector<llvm::Metadata *> operands;

    operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
    operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));

    operands.push_back(llvm::ConstantAsMetadata::get(
        llvm::ConstantInt::get(offsetTy, tag.getOffset())));
    if (tag.getConstant())
      operands.push_back(
          llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy, 1)));

    tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
  });

  // Only attributes reachable from operations' TBAA tag lists are translated.
  mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
    if (auto attr = analysisOpInterface.getTBAATagsOrNull())
      walker.walk(attr);
  });

  return success();
}
1251 
1252 void ModuleTranslation::setLoopMetadata(Operation *op,
1253                                         llvm::Instruction *inst) {
1254   LoopAnnotationAttr attr =
1255       TypeSwitch<Operation *, LoopAnnotationAttr>(op)
1256           .Case<LLVM::BrOp, LLVM::CondBrOp>(
1257               [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
1258   if (!attr)
1259     return;
1260   llvm::MDNode *loopMD =
1261       loopAnnotationTranslation->translateLoopAnnotation(attr, op);
1262   inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
1263 }
1264 
1265 llvm::Type *ModuleTranslation::convertType(Type type) {
1266   return typeTranslator.translateType(type);
1267 }
1268 
1269 /// A helper to look up remapped operands in the value remapping table.
1270 SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
1271   SmallVector<llvm::Value *> remapped;
1272   remapped.reserve(values.size());
1273   for (Value v : values)
1274     remapped.push_back(lookupValue(v));
1275   return remapped;
1276 }
1277 
1278 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
1279   if (!ompBuilder) {
1280     ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
1281 
1282     bool isTargetDevice = false, isGPU = false;
1283     llvm::StringRef hostIRFilePath = "";
1284 
1285     if (auto deviceAttr =
1286             mlirModule->getAttrOfType<mlir::BoolAttr>("omp.is_target_device"))
1287       isTargetDevice = deviceAttr.getValue();
1288 
1289     if (auto gpuAttr = mlirModule->getAttrOfType<mlir::BoolAttr>("omp.is_gpu"))
1290       isGPU = gpuAttr.getValue();
1291 
1292     if (auto filepathAttr =
1293             mlirModule->getAttrOfType<mlir::StringAttr>("omp.host_ir_filepath"))
1294       hostIRFilePath = filepathAttr.getValue();
1295 
1296     ompBuilder->initialize(hostIRFilePath);
1297 
1298     // TODO: set the flags when available
1299     llvm::OpenMPIRBuilderConfig config(
1300         isTargetDevice, isGPU,
1301         /* HasRequiresUnifiedSharedMemory */ false,
1302         /* OpenMPOffloadMandatory */ false);
1303     ompBuilder->setConfig(config);
1304   }
1305   return ompBuilder.get();
1306 }
1307 
1308 llvm::DILocation *ModuleTranslation::translateLoc(Location loc,
1309                                                   llvm::DILocalScope *scope) {
1310   return debugTranslation->translateLoc(loc, scope);
1311 }
1312 
1313 llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
1314   return debugTranslation->translate(attr);
1315 }
1316 
1317 llvm::NamedMDNode *
1318 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
1319   return llvmModule->getOrInsertNamedMetadata(name);
1320 }
1321 
1322 void ModuleTranslation::StackFrame::anchor() {}
1323 
1324 static std::unique_ptr<llvm::Module>
1325 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1326                   StringRef name) {
1327   m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1328   auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1329   if (auto dataLayoutAttr =
1330           m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1331     llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
1332   } else {
1333     FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1334     if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1335       if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1336         llvmDataLayout =
1337             translateDataLayout(spec, DataLayout(iface), m->getLoc());
1338       }
1339     } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1340       if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1341         llvmDataLayout =
1342             translateDataLayout(spec, DataLayout(mod), m->getLoc());
1343       }
1344     }
1345     if (failed(llvmDataLayout))
1346       return nullptr;
1347     llvmModule->setDataLayout(*llvmDataLayout);
1348   }
1349   if (auto targetTripleAttr =
1350           m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1351     llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
1352 
1353   // Inject declarations for `malloc` and `free` functions that can be used in
1354   // memref allocation/deallocation coming from standard ops lowering.
1355   llvm::IRBuilder<> builder(llvmContext);
1356   llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
1357                                   builder.getInt64Ty());
1358   llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
1359                                   builder.getInt8PtrTy());
1360 
1361   return llvmModule;
1362 }
1363 
1364 std::unique_ptr<llvm::Module>
1365 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1366                               StringRef name) {
1367   if (!satisfiesLLVMModule(module)) {
1368     module->emitOpError("can not be translated to an LLVMIR module");
1369     return nullptr;
1370   }
1371 
1372   std::unique_ptr<llvm::Module> llvmModule =
1373       prepareLLVMModule(module, llvmContext, name);
1374   if (!llvmModule)
1375     return nullptr;
1376 
1377   LLVM::ensureDistinctSuccessors(module);
1378 
1379   ModuleTranslation translator(module, std::move(llvmModule));
1380   llvm::IRBuilder<> llvmBuilder(llvmContext);
1381 
1382   // Convert module before functions and operations inside, so dialect
1383   // attributes can be used to change dialect-specific global configurations via
1384   // `amendOperation()`. These configurations can then influence the translation
1385   // of operations afterwards.
1386   if (failed(translator.convertOperation(*module, llvmBuilder)))
1387     return nullptr;
1388 
1389   if (failed(translator.convertComdats()))
1390     return nullptr;
1391   if (failed(translator.convertFunctionSignatures()))
1392     return nullptr;
1393   if (failed(translator.convertGlobals()))
1394     return nullptr;
1395   if (failed(translator.createTBAAMetadata()))
1396     return nullptr;
1397 
1398   // Convert other top-level operations if possible.
1399   for (Operation &o : getModuleBody(module).getOperations()) {
1400     if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1401              LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
1402         !o.hasTrait<OpTrait::IsTerminator>() &&
1403         failed(translator.convertOperation(o, llvmBuilder))) {
1404       return nullptr;
1405     }
1406   }
1407 
1408   // Operations in function bodies with symbolic references must be converted
1409   // after the top-level operations they refer to are declared, so we do it
1410   // last.
1411   if (failed(translator.convertFunctions()))
1412     return nullptr;
1413 
1414   if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1415     return nullptr;
1416 
1417   return std::move(translator.llvmModule);
1418 }
1419