xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (revision e504ece6c15fa5b347a4d8ff7e6fc98ee109660e)
1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "TypeDetail.h"
16 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
18 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 
29 #include "llvm/ADT/SCCIterator.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/AsmParser/Parser.h"
32 #include "llvm/Bitcode/BitcodeReader.h"
33 #include "llvm/Bitcode/BitcodeWriter.h"
34 #include "llvm/IR/Attributes.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/Error.h"
38 #include "llvm/Support/Mutex.h"
39 #include "llvm/Support/SourceMgr.h"
40 
41 #include <numeric>
42 #include <optional>
43 
44 using namespace mlir;
45 using namespace mlir::LLVM;
46 using mlir::LLVM::cconv::getMaxEnumValForCConv;
47 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
48 using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
49 
50 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
51 
52 //===----------------------------------------------------------------------===//
53 // Property Helpers
54 //===----------------------------------------------------------------------===//
55 
56 //===----------------------------------------------------------------------===//
57 // IntegerOverflowFlags
58 
59 namespace mlir {
60 static Attribute convertToAttribute(MLIRContext *ctx,
61                                     IntegerOverflowFlags flags) {
62   return IntegerOverflowFlagsAttr::get(ctx, flags);
63 }
64 
65 static LogicalResult
66 convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
67                      function_ref<InFlightDiagnostic()> emitError) {
68   auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
69   if (!flagsAttr) {
70     return emitError() << "expected 'overflowFlags' attribute to be an "
71                           "IntegerOverflowFlagsAttr, but got "
72                        << attr;
73   }
74   flags = flagsAttr.getValue();
75   return success();
76 }
77 } // namespace mlir
78 
79 static ParseResult parseOverflowFlags(AsmParser &p,
80                                       IntegerOverflowFlags &flags) {
81   if (failed(p.parseOptionalKeyword("overflow"))) {
82     flags = IntegerOverflowFlags::none;
83     return success();
84   }
85   if (p.parseLess())
86     return failure();
87   do {
88     StringRef kw;
89     SMLoc loc = p.getCurrentLocation();
90     if (p.parseKeyword(&kw))
91       return failure();
92     std::optional<IntegerOverflowFlags> flag =
93         symbolizeIntegerOverflowFlags(kw);
94     if (!flag)
95       return p.emitError(loc,
96                          "invalid overflow flag: expected nsw, nuw, or none");
97     flags = flags | *flag;
98   } while (succeeded(p.parseOptionalComma()));
99   return p.parseGreater();
100 }
101 
102 static void printOverflowFlags(AsmPrinter &p, Operation *op,
103                                IntegerOverflowFlags flags) {
104   if (flags == IntegerOverflowFlags::none)
105     return;
106   p << " overflow<";
107   SmallVector<StringRef, 2> strs;
108   if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
109     strs.push_back("nsw");
110   if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
111     strs.push_back("nuw");
112   llvm::interleaveComma(strs, p);
113   p << ">";
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // Attribute Helpers
118 //===----------------------------------------------------------------------===//
119 
/// Attribute name under which ops in this file store their element type.
static constexpr char kElemTypeAttrName[] = "elem_type";
121 
122 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
123   SmallVector<NamedAttribute, 8> filteredAttrs(
124       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
125         if (attr.getName() == "fastmathFlags") {
126           auto defAttr =
127               FastmathFlagsAttr::get(attr.getValue().getContext(), {});
128           return defAttr != attr.getValue();
129         }
130         return true;
131       }));
132   return filteredAttrs;
133 }
134 
135 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
136 /// fully defined llvm.func.
137 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
138                                          Operation *op,
139                                          SymbolTableCollection &symbolTable) {
140   StringRef name = symbol.getValue();
141   auto func =
142       symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
143   if (!func)
144     return op->emitOpError("'")
145            << name << "' does not reference a valid LLVM function";
146   if (func.isExternal())
147     return op->emitOpError("'") << name << "' does not have a definition";
148   return success();
149 }
150 
151 /// Returns a boolean type that has the same shape as `type`. It supports both
152 /// fixed size vectors as well as scalable vectors.
153 static Type getI1SameShape(Type type) {
154   Type i1Type = IntegerType::get(type.getContext(), 1);
155   if (LLVM::isCompatibleVectorType(type))
156     return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
157   return i1Type;
158 }
159 
160 // Parses one of the keywords provided in the list `keywords` and returns the
161 // position of the parsed keyword in the list. If none of the keywords from the
162 // list is parsed, returns -1.
163 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
164                                            ArrayRef<StringRef> keywords) {
165   for (const auto &en : llvm::enumerate(keywords)) {
166     if (succeeded(parser.parseOptionalKeyword(en.value())))
167       return en.index();
168   }
169   return -1;
170 }
171 
172 namespace {
173 template <typename Ty>
174 struct EnumTraits {};
175 
176 #define REGISTER_ENUM_TYPE(Ty)                                                 \
177   template <>                                                                  \
178   struct EnumTraits<Ty> {                                                      \
179     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
180     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
181   }
182 
183 REGISTER_ENUM_TYPE(Linkage);
184 REGISTER_ENUM_TYPE(UnnamedAddr);
185 REGISTER_ENUM_TYPE(CConv);
186 REGISTER_ENUM_TYPE(TailCallKind);
187 REGISTER_ENUM_TYPE(Visibility);
188 } // namespace
189 
190 /// Parse an enum from the keyword, or default to the provided default value.
191 /// The return type is the enum type by default, unless overridden with the
192 /// second template argument.
193 template <typename EnumTy, typename RetTy = EnumTy>
194 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
195                                       OperationState &result,
196                                       EnumTy defaultValue) {
197   SmallVector<StringRef, 10> names;
198   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
199     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
200 
201   int index = parseOptionalKeywordAlternative(parser, names);
202   if (index == -1)
203     return static_cast<RetTy>(defaultValue);
204   return static_cast<RetTy>(index);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Operand bundle helpers.
209 //===----------------------------------------------------------------------===//
210 
211 static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
212                              TypeRange operandTypes, StringRef tag) {
213   p.printString(tag);
214   p << "(";
215 
216   if (!operands.empty()) {
217     p.printOperands(operands);
218     p << " : ";
219     llvm::interleaveComma(operandTypes, p);
220   }
221 
222   p << ")";
223 }
224 
225 static void printOpBundles(OpAsmPrinter &p, Operation *op,
226                            OperandRangeRange opBundleOperands,
227                            TypeRangeRange opBundleOperandTypes,
228                            std::optional<ArrayAttr> opBundleTags) {
229   if (opBundleOperands.empty())
230     return;
231   assert(opBundleTags && "expect operand bundle tags");
232 
233   p << "[";
234   llvm::interleaveComma(
235       llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
236       [&p](auto bundle) {
237         auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
238         printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
239                          bundleTag);
240       });
241   p << "]";
242 }
243 
244 static ParseResult parseOneOpBundle(
245     OpAsmParser &p,
246     SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
247     SmallVector<SmallVector<Type>> &opBundleOperandTypes,
248     SmallVector<Attribute> &opBundleTags) {
249   SMLoc currentParserLoc = p.getCurrentLocation();
250   SmallVector<OpAsmParser::UnresolvedOperand> operands;
251   SmallVector<Type> types;
252   std::string tag;
253 
254   if (p.parseString(&tag))
255     return p.emitError(currentParserLoc, "expect operand bundle tag");
256 
257   if (p.parseLParen())
258     return failure();
259 
260   if (p.parseOptionalRParen()) {
261     if (p.parseOperandList(operands) || p.parseColon() ||
262         p.parseTypeList(types) || p.parseRParen())
263       return failure();
264   }
265 
266   opBundleOperands.push_back(std::move(operands));
267   opBundleOperandTypes.push_back(std::move(types));
268   opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
269 
270   return success();
271 }
272 
273 static std::optional<ParseResult> parseOpBundles(
274     OpAsmParser &p,
275     SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
276     SmallVector<SmallVector<Type>> &opBundleOperandTypes,
277     ArrayAttr &opBundleTags) {
278   if (p.parseOptionalLSquare())
279     return std::nullopt;
280 
281   if (succeeded(p.parseOptionalRSquare()))
282     return success();
283 
284   SmallVector<Attribute> opBundleTagAttrs;
285   auto bundleParser = [&] {
286     return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
287                             opBundleTagAttrs);
288   };
289   if (p.parseCommaSeparatedList(bundleParser))
290     return failure();
291 
292   if (p.parseRSquare())
293     return failure();
294 
295   opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
296 
297   return success();
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Printing, parsing, folding and builder for LLVM::CmpOp.
302 //===----------------------------------------------------------------------===//
303 
304 void ICmpOp::print(OpAsmPrinter &p) {
305   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
306     << ", " << getOperand(1);
307   p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
308   p << " : " << getLhs().getType();
309 }
310 
311 void FCmpOp::print(OpAsmPrinter &p) {
312   p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
313     << ", " << getOperand(1);
314   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
315   p << " : " << getLhs().getType();
316 }
317 
318 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
319 //                 attribute-dict? `:` type
320 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
321 //                 attribute-dict? `:` type
322 template <typename CmpPredicateType>
323 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
324   StringAttr predicateAttr;
325   OpAsmParser::UnresolvedOperand lhs, rhs;
326   Type type;
327   SMLoc predicateLoc, trailingTypeLoc;
328   if (parser.getCurrentLocation(&predicateLoc) ||
329       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
330       parser.parseOperand(lhs) || parser.parseComma() ||
331       parser.parseOperand(rhs) ||
332       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
333       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
334       parser.resolveOperand(lhs, type, result.operands) ||
335       parser.resolveOperand(rhs, type, result.operands))
336     return failure();
337 
338   // Replace the string attribute `predicate` with an integer attribute.
339   int64_t predicateValue = 0;
340   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
341     std::optional<ICmpPredicate> predicate =
342         symbolizeICmpPredicate(predicateAttr.getValue());
343     if (!predicate)
344       return parser.emitError(predicateLoc)
345              << "'" << predicateAttr.getValue()
346              << "' is an incorrect value of the 'predicate' attribute";
347     predicateValue = static_cast<int64_t>(*predicate);
348   } else {
349     std::optional<FCmpPredicate> predicate =
350         symbolizeFCmpPredicate(predicateAttr.getValue());
351     if (!predicate)
352       return parser.emitError(predicateLoc)
353              << "'" << predicateAttr.getValue()
354              << "' is an incorrect value of the 'predicate' attribute";
355     predicateValue = static_cast<int64_t>(*predicate);
356   }
357 
358   result.attributes.set("predicate",
359                         parser.getBuilder().getI64IntegerAttr(predicateValue));
360 
361   // The result type is either i1 or a vector type <? x i1> if the inputs are
362   // vectors.
363   if (!isCompatibleType(type))
364     return parser.emitError(trailingTypeLoc,
365                             "expected LLVM dialect-compatible type");
366   result.addTypes(getI1SameShape(type));
367   return success();
368 }
369 
370 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
371   return parseCmpOp<ICmpPredicate>(parser, result);
372 }
373 
374 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
375   return parseCmpOp<FCmpPredicate>(parser, result);
376 }
377 
378 /// Returns a scalar or vector boolean attribute of the given type.
379 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
380   auto boolAttr = BoolAttr::get(ctx, value);
381   ShapedType shapedType = dyn_cast<ShapedType>(type);
382   if (!shapedType)
383     return boolAttr;
384   return DenseElementsAttr::get(shapedType, boolAttr);
385 }
386 
387 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
388   if (getPredicate() != ICmpPredicate::eq &&
389       getPredicate() != ICmpPredicate::ne)
390     return {};
391 
392   // cmpi(eq/ne, x, x) -> true/false
393   if (getLhs() == getRhs())
394     return getBoolAttribute(getType(), getContext(),
395                             getPredicate() == ICmpPredicate::eq);
396 
397   // cmpi(eq/ne, alloca, null) -> false/true
398   if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
399     return getBoolAttribute(getType(), getContext(),
400                             getPredicate() == ICmpPredicate::ne);
401 
402   // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
403   if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
404     Value lhs = getLhs();
405     Value rhs = getRhs();
406     getLhsMutable().assign(rhs);
407     getRhsMutable().assign(lhs);
408     return getResult();
409   }
410 
411   return {};
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // Printing, parsing and verification for LLVM::AllocaOp.
416 //===----------------------------------------------------------------------===//
417 
418 void AllocaOp::print(OpAsmPrinter &p) {
419   auto funcTy =
420       FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
421 
422   if (getInalloca())
423     p << " inalloca";
424 
425   p << ' ' << getArraySize() << " x " << getElemType();
426   if (getAlignment() && *getAlignment() != 0)
427     p.printOptionalAttrDict((*this)->getAttrs(),
428                             {kElemTypeAttrName, getInallocaAttrName()});
429   else
430     p.printOptionalAttrDict(
431         (*this)->getAttrs(),
432         {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
433   p << " : " << funcTy;
434 }
435 
436 // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
437 //                  attribute-dict? `:` type `,` type
438 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
439   OpAsmParser::UnresolvedOperand arraySize;
440   Type type, elemType;
441   SMLoc trailingTypeLoc;
442 
443   if (succeeded(parser.parseOptionalKeyword("inalloca")))
444     result.addAttribute(getInallocaAttrName(result.name),
445                         UnitAttr::get(parser.getContext()));
446 
447   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
448       parser.parseType(elemType) ||
449       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
450       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
451     return failure();
452 
453   std::optional<NamedAttribute> alignmentAttr =
454       result.attributes.getNamed("alignment");
455   if (alignmentAttr.has_value()) {
456     auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
457     if (!alignmentInt)
458       return parser.emitError(parser.getNameLoc(),
459                               "expected integer alignment");
460     if (alignmentInt.getValue().isZero())
461       result.attributes.erase("alignment");
462   }
463 
464   // Extract the result type from the trailing function type.
465   auto funcType = llvm::dyn_cast<FunctionType>(type);
466   if (!funcType || funcType.getNumInputs() != 1 ||
467       funcType.getNumResults() != 1)
468     return parser.emitError(
469         trailingTypeLoc,
470         "expected trailing function type with one argument and one result");
471 
472   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
473     return failure();
474 
475   Type resultType = funcType.getResult(0);
476   if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
477     result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
478 
479   result.addTypes({funcType.getResult(0)});
480   return success();
481 }
482 
483 LogicalResult AllocaOp::verify() {
484   // Only certain target extension types can be used in 'alloca'.
485   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
486       targetExtType && !targetExtType.supportsMemOps())
487     return emitOpError()
488            << "this target extension type cannot be used in alloca";
489 
490   return success();
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // LLVM::BrOp
495 //===----------------------------------------------------------------------===//
496 
497 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
498   assert(index == 0 && "invalid successor index");
499   return SuccessorOperands(getDestOperandsMutable());
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // LLVM::CondBrOp
504 //===----------------------------------------------------------------------===//
505 
506 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
507   assert(index < getNumSuccessors() && "invalid successor index");
508   return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
509                                       : getFalseDestOperandsMutable());
510 }
511 
512 void CondBrOp::build(OpBuilder &builder, OperationState &result,
513                      Value condition, Block *trueDest, ValueRange trueOperands,
514                      Block *falseDest, ValueRange falseOperands,
515                      std::optional<std::pair<uint32_t, uint32_t>> weights) {
516   DenseI32ArrayAttr weightsAttr;
517   if (weights)
518     weightsAttr =
519         builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
520                                       static_cast<int32_t>(weights->second)});
521 
522   build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
523         /*loop_annotation=*/{}, trueDest, falseDest);
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // LLVM::SwitchOp
528 //===----------------------------------------------------------------------===//
529 
530 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
531                      Block *defaultDestination, ValueRange defaultOperands,
532                      DenseIntElementsAttr caseValues,
533                      BlockRange caseDestinations,
534                      ArrayRef<ValueRange> caseOperands,
535                      ArrayRef<int32_t> branchWeights) {
536   DenseI32ArrayAttr weightsAttr;
537   if (!branchWeights.empty())
538     weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
539 
540   build(builder, result, value, defaultOperands, caseOperands, caseValues,
541         weightsAttr, defaultDestination, caseDestinations);
542 }
543 
544 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
545                      Block *defaultDestination, ValueRange defaultOperands,
546                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
547                      ArrayRef<ValueRange> caseOperands,
548                      ArrayRef<int32_t> branchWeights) {
549   DenseIntElementsAttr caseValuesAttr;
550   if (!caseValues.empty()) {
551     ShapedType caseValueType = VectorType::get(
552         static_cast<int64_t>(caseValues.size()), value.getType());
553     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
554   }
555 
556   build(builder, result, value, defaultDestination, defaultOperands,
557         caseValuesAttr, caseDestinations, caseOperands, branchWeights);
558 }
559 
560 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
561                      Block *defaultDestination, ValueRange defaultOperands,
562                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
563                      ArrayRef<ValueRange> caseOperands,
564                      ArrayRef<int32_t> branchWeights) {
565   DenseIntElementsAttr caseValuesAttr;
566   if (!caseValues.empty()) {
567     ShapedType caseValueType = VectorType::get(
568         static_cast<int64_t>(caseValues.size()), value.getType());
569     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
570   }
571 
572   build(builder, result, value, defaultDestination, defaultOperands,
573         caseValuesAttr, caseDestinations, caseOperands, branchWeights);
574 }
575 
576 /// <cases> ::= `[` (case (`,` case )* )? `]`
577 /// <case>  ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
578 static ParseResult parseSwitchOpCases(
579     OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
580     SmallVectorImpl<Block *> &caseDestinations,
581     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
582     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
583   if (failed(parser.parseLSquare()))
584     return failure();
585   if (succeeded(parser.parseOptionalRSquare()))
586     return success();
587   SmallVector<APInt> values;
588   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
589   auto parseCase = [&]() {
590     int64_t value = 0;
591     if (failed(parser.parseInteger(value)))
592       return failure();
593     values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
594 
595     Block *destination;
596     SmallVector<OpAsmParser::UnresolvedOperand> operands;
597     SmallVector<Type> operandTypes;
598     if (parser.parseColon() || parser.parseSuccessor(destination))
599       return failure();
600     if (!parser.parseOptionalLParen()) {
601       if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
602                                   /*allowResultNumber=*/false) ||
603           parser.parseColonTypeList(operandTypes) || parser.parseRParen())
604         return failure();
605     }
606     caseDestinations.push_back(destination);
607     caseOperands.emplace_back(operands);
608     caseOperandTypes.emplace_back(operandTypes);
609     return success();
610   };
611   if (failed(parser.parseCommaSeparatedList(parseCase)))
612     return failure();
613 
614   ShapedType caseValueType =
615       VectorType::get(static_cast<int64_t>(values.size()), flagType);
616   caseValues = DenseIntElementsAttr::get(caseValueType, values);
617   return parser.parseRSquare();
618 }
619 
620 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
621                                DenseIntElementsAttr caseValues,
622                                SuccessorRange caseDestinations,
623                                OperandRangeRange caseOperands,
624                                const TypeRangeRange &caseOperandTypes) {
625   p << '[';
626   p.printNewline();
627   if (!caseValues) {
628     p << ']';
629     return;
630   }
631 
632   size_t index = 0;
633   llvm::interleave(
634       llvm::zip(caseValues, caseDestinations),
635       [&](auto i) {
636         p << "  ";
637         p << std::get<0>(i).getLimitedValue();
638         p << ": ";
639         p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
640       },
641       [&] {
642         p << ',';
643         p.printNewline();
644       });
645   p.printNewline();
646   p << ']';
647 }
648 
649 LogicalResult SwitchOp::verify() {
650   if ((!getCaseValues() && !getCaseDestinations().empty()) ||
651       (getCaseValues() &&
652        getCaseValues()->size() !=
653            static_cast<int64_t>(getCaseDestinations().size())))
654     return emitOpError("expects number of case values to match number of "
655                        "case destinations");
656   if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
657     return emitError("expects number of branch weights to match number of "
658                      "successors: ")
659            << getBranchWeights()->size() << " vs " << getNumSuccessors();
660   if (getCaseValues() &&
661       getValue().getType() != getCaseValues()->getElementType())
662     return emitError("expects case value type to match condition value type");
663   return success();
664 }
665 
666 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
667   assert(index < getNumSuccessors() && "invalid successor index");
668   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
669                                       : getCaseOperandsMutable(index - 1));
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // Code for LLVM::GEPOp.
674 //===----------------------------------------------------------------------===//
675 
// Namespace-scope definition of the static constexpr data member. It has no
// initializer here, so the value must come from the in-class declaration
// (not visible in this part of the file).
constexpr int32_t GEPOp::kDynamicIndex;
677 
678 GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
679   return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
680                                        getDynamicIndices());
681 }
682 
683 /// Returns the elemental type of any LLVM-compatible vector type or self.
684 static Type extractVectorElementType(Type type) {
685   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
686     return vectorType.getElementType();
687   if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
688     return scalableVectorType.getElementType();
689   if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
690     return fixedVectorType.getElementType();
691   return type;
692 }
693 
694 /// Destructures the 'indices' parameter into 'rawConstantIndices' and
695 /// 'dynamicIndices', encoding the former in the process. In the process,
696 /// dynamic indices which are used to index into a structure type are converted
697 /// to constant indices when possible. To do this, the GEPs element type should
698 /// be passed as first parameter.
699 static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
700                                SmallVectorImpl<int32_t> &rawConstantIndices,
701                                SmallVectorImpl<Value> &dynamicIndices) {
702   for (const GEPArg &iter : indices) {
703     // If the thing we are currently indexing into is a struct we must turn
704     // any integer constants into constant indices. If this is not possible
705     // we don't do anything here. The verifier will catch it and emit a proper
706     // error. All other canonicalization is done in the fold method.
707     bool requiresConst = !rawConstantIndices.empty() &&
708                          isa_and_nonnull<LLVMStructType>(currType);
709     if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
710       APInt intC;
711       if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
712           intC.isSignedIntN(kGEPConstantBitWidth)) {
713         rawConstantIndices.push_back(intC.getSExtValue());
714       } else {
715         rawConstantIndices.push_back(GEPOp::kDynamicIndex);
716         dynamicIndices.push_back(val);
717       }
718     } else {
719       rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
720     }
721 
722     // Skip for very first iteration of this loop. First index does not index
723     // within the aggregates, but is just a pointer offset.
724     if (rawConstantIndices.size() == 1 || !currType)
725       continue;
726 
727     currType =
728         TypeSwitch<Type, Type>(currType)
729             .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
730                   LLVMArrayType>([](auto containerType) {
731               return containerType.getElementType();
732             })
733             .Case([&](LLVMStructType structType) -> Type {
734               int64_t memberIndex = rawConstantIndices.back();
735               if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
736                                           structType.getBody().size())
737                 return structType.getBody()[memberIndex];
738               return nullptr;
739             })
740             .Default(Type(nullptr));
741   }
742 }
743 
744 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
745                   Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
746                   bool inbounds, ArrayRef<NamedAttribute> attributes) {
747   SmallVector<int32_t> rawConstantIndices;
748   SmallVector<Value> dynamicIndices;
749   destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
750 
751   result.addTypes(resultType);
752   result.addAttributes(attributes);
753   result.addAttribute(getRawConstantIndicesAttrName(result.name),
754                       builder.getDenseI32ArrayAttr(rawConstantIndices));
755   if (inbounds) {
756     result.addAttribute(getInboundsAttrName(result.name),
757                         builder.getUnitAttr());
758   }
759   result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
760   result.addOperands(basePtr);
761   result.addOperands(dynamicIndices);
762 }
763 
764 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
765                   Type elementType, Value basePtr, ValueRange indices,
766                   bool inbounds, ArrayRef<NamedAttribute> attributes) {
767   build(builder, result, resultType, elementType, basePtr,
768         SmallVector<GEPArg>(indices), inbounds, attributes);
769 }
770 
771 static ParseResult
772 parseGEPIndices(OpAsmParser &parser,
773                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
774                 DenseI32ArrayAttr &rawConstantIndices) {
775   SmallVector<int32_t> constantIndices;
776 
777   auto idxParser = [&]() -> ParseResult {
778     int32_t constantIndex;
779     OptionalParseResult parsedInteger =
780         parser.parseOptionalInteger(constantIndex);
781     if (parsedInteger.has_value()) {
782       if (failed(parsedInteger.value()))
783         return failure();
784       constantIndices.push_back(constantIndex);
785       return success();
786     }
787 
788     constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
789     return parser.parseOperand(indices.emplace_back());
790   };
791   if (parser.parseCommaSeparatedList(idxParser))
792     return failure();
793 
794   rawConstantIndices =
795       DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
796   return success();
797 }
798 
799 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
800                             OperandRange indices,
801                             DenseI32ArrayAttr rawConstantIndices) {
802   llvm::interleaveComma(
803       GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
804       [&](PointerUnion<IntegerAttr, Value> cst) {
805         if (Value val = llvm::dyn_cast_if_present<Value>(cst))
806           printer.printOperand(val);
807         else
808           printer << cast<IntegerAttr>(cst).getInt();
809       });
810 }
811 
812 /// For the given `indices`, check if they comply with `baseGEPType`,
813 /// especially check against LLVMStructTypes nested within.
814 static LogicalResult
815 verifyStructIndices(Type baseGEPType, unsigned indexPos,
816                     GEPIndicesAdaptor<ValueRange> indices,
817                     function_ref<InFlightDiagnostic()> emitOpError) {
818   if (indexPos >= indices.size())
819     // Stop searching
820     return success();
821 
822   return TypeSwitch<Type, LogicalResult>(baseGEPType)
823       .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
824         auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
825         if (!attr)
826           return emitOpError() << "expected index " << indexPos
827                                << " indexing a struct to be constant";
828 
829         int32_t gepIndex = attr.getInt();
830         ArrayRef<Type> elementTypes = structType.getBody();
831         if (gepIndex < 0 ||
832             static_cast<size_t>(gepIndex) >= elementTypes.size())
833           return emitOpError() << "index " << indexPos
834                                << " indexing a struct is out of bounds";
835 
836         // Instead of recursively going into every children types, we only
837         // dive into the one indexed by gepIndex.
838         return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
839                                    indices, emitOpError);
840       })
841       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
842             LLVMArrayType>([&](auto containerType) -> LogicalResult {
843         return verifyStructIndices(containerType.getElementType(), indexPos + 1,
844                                    indices, emitOpError);
845       })
846       .Default([&](auto otherType) -> LogicalResult {
847         return emitOpError()
848                << "type " << otherType << " cannot be indexed (index #"
849                << indexPos << ")";
850       });
851 }
852 
853 /// Driver function around `verifyStructIndices`.
854 static LogicalResult
855 verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
856                     function_ref<InFlightDiagnostic()> emitOpError) {
857   return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
858 }
859 
860 LogicalResult LLVM::GEPOp::verify() {
861   if (static_cast<size_t>(
862           llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
863       getDynamicIndices().size())
864     return emitOpError("expected as many dynamic indices as specified in '")
865            << getRawConstantIndicesAttrName().getValue() << "'";
866 
867   return verifyStructIndices(getElemType(), getIndices(),
868                              [&] { return emitOpError(); });
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // LoadOp
873 //===----------------------------------------------------------------------===//
874 
875 void LoadOp::getEffects(
876     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
877         &effects) {
878   effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
879   // Volatile operations can have target-specific read-write effects on
880   // memory besides the one referred to by the pointer operand.
881   // Similarly, atomic operations that are monotonic or stricter cause
882   // synchronization that from a language point-of-view, are arbitrary
883   // read-writes into memory.
884   if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
885                          getOrdering() != AtomicOrdering::unordered)) {
886     effects.emplace_back(MemoryEffects::Write::get());
887     effects.emplace_back(MemoryEffects::Read::get());
888   }
889 }
890 
891 /// Returns true if the given type is supported by atomic operations. All
892 /// integer, float, and pointer types with a power-of-two bitsize and a minimal
893 /// size of 8 bits are supported.
894 static bool isTypeCompatibleWithAtomicOp(Type type,
895                                          const DataLayout &dataLayout) {
896   if (!isa<IntegerType, LLVMPointerType>(type))
897     if (!isCompatibleFloatingPointType(type))
898       return false;
899 
900   llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type);
901   if (bitWidth.isScalable())
902     return false;
903   // Needs to be at least 8 bits and a power of two.
904   return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
905 }
906 
907 /// Verifies the attributes and the type of atomic memory access operations.
908 template <typename OpTy>
909 LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
910                                 ArrayRef<AtomicOrdering> unsupportedOrderings) {
911   if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
912     DataLayout dataLayout = DataLayout::closest(memOp);
913     if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
914       return memOp.emitOpError("unsupported type ")
915              << valueType << " for atomic access";
916     if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
917       return memOp.emitOpError("unsupported ordering '")
918              << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
919     if (!memOp.getAlignment())
920       return memOp.emitOpError("expected alignment for atomic access");
921     return success();
922   }
923   if (memOp.getSyncscope())
924     return memOp.emitOpError(
925         "expected syncscope to be null for non-atomic access");
926   return success();
927 }
928 
929 LogicalResult LoadOp::verify() {
930   Type valueType = getResult().getType();
931   return verifyAtomicMemOp(*this, valueType,
932                            {AtomicOrdering::release, AtomicOrdering::acq_rel});
933 }
934 
935 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
936                    Value addr, unsigned alignment, bool isVolatile,
937                    bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
938                    AtomicOrdering ordering, StringRef syncscope) {
939   build(builder, state, type, addr,
940         alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
941         isNonTemporal, isInvariant, isInvariantGroup, ordering,
942         syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
943         /*access_groups=*/nullptr,
944         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
945         /*tbaa=*/nullptr);
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // StoreOp
950 //===----------------------------------------------------------------------===//
951 
952 void StoreOp::getEffects(
953     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
954         &effects) {
955   effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
956   // Volatile operations can have target-specific read-write effects on
957   // memory besides the one referred to by the pointer operand.
958   // Similarly, atomic operations that are monotonic or stricter cause
959   // synchronization that from a language point-of-view, are arbitrary
960   // read-writes into memory.
961   if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
962                          getOrdering() != AtomicOrdering::unordered)) {
963     effects.emplace_back(MemoryEffects::Write::get());
964     effects.emplace_back(MemoryEffects::Read::get());
965   }
966 }
967 
968 LogicalResult StoreOp::verify() {
969   Type valueType = getValue().getType();
970   return verifyAtomicMemOp(*this, valueType,
971                            {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
972 }
973 
974 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
975                     Value addr, unsigned alignment, bool isVolatile,
976                     bool isNonTemporal, bool isInvariantGroup,
977                     AtomicOrdering ordering, StringRef syncscope) {
978   build(builder, state, value, addr,
979         alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
980         isNonTemporal, isInvariantGroup, ordering,
981         syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
982         /*access_groups=*/nullptr,
983         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
984 }
985 
986 //===----------------------------------------------------------------------===//
987 // CallOp
988 //===----------------------------------------------------------------------===//
989 
990 /// Gets the MLIR Op-like result types of a LLVMFunctionType.
991 static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
992   SmallVector<Type, 1> results;
993   Type resultType = calleeType.getReturnType();
994   if (!isa<LLVM::LLVMVoidType>(resultType))
995     results.push_back(resultType);
996   return results;
997 }
998 
999 /// Gets the variadic callee type for a LLVMFunctionType.
1000 static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
1001   return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1002 }
1003 
1004 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
1005 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
1006                                         ValueRange args) {
1007   Type resultType;
1008   if (results.empty())
1009     resultType = LLVMVoidType::get(context);
1010   else
1011     resultType = results.front();
1012   return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
1013                                /*isVarArg=*/false);
1014 }
1015 
1016 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
1017                    StringRef callee, ValueRange args) {
1018   build(builder, state, results, builder.getStringAttr(callee), args);
1019 }
1020 
1021 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
1022                    StringAttr callee, ValueRange args) {
1023   build(builder, state, results, SymbolRefAttr::get(callee), args);
1024 }
1025 
1026 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
1027                    FlatSymbolRefAttr callee, ValueRange args) {
1028   assert(callee && "expected non-null callee in direct call builder");
1029   build(builder, state, results,
1030         /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
1031         /*branch_weights=*/nullptr,
1032         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
1033         /*memory_effects=*/nullptr,
1034         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1035         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1036         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1037         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1038 }
1039 
1040 void CallOp::build(OpBuilder &builder, OperationState &state,
1041                    LLVMFunctionType calleeType, StringRef callee,
1042                    ValueRange args) {
1043   build(builder, state, calleeType, builder.getStringAttr(callee), args);
1044 }
1045 
1046 void CallOp::build(OpBuilder &builder, OperationState &state,
1047                    LLVMFunctionType calleeType, StringAttr callee,
1048                    ValueRange args) {
1049   build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
1050 }
1051 
1052 void CallOp::build(OpBuilder &builder, OperationState &state,
1053                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1054                    ValueRange args) {
1055   build(builder, state, getCallOpResultTypes(calleeType),
1056         getCallOpVarCalleeType(calleeType), callee, args,
1057         /*fastmathFlags=*/nullptr,
1058         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
1059         /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1060         /*convergent=*/nullptr,
1061         /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1062         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1063         /*access_groups=*/nullptr,
1064         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1065 }
1066 
1067 void CallOp::build(OpBuilder &builder, OperationState &state,
1068                    LLVMFunctionType calleeType, ValueRange args) {
1069   build(builder, state, getCallOpResultTypes(calleeType),
1070         getCallOpVarCalleeType(calleeType),
1071         /*callee=*/nullptr, args,
1072         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1073         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1074         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1075         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1076         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1077         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1078 }
1079 
1080 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1081                    ValueRange args) {
1082   auto calleeType = func.getFunctionType();
1083   build(builder, state, getCallOpResultTypes(calleeType),
1084         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1085         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1086         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1087         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1088         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1089         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1090         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1091 }
1092 
1093 CallInterfaceCallable CallOp::getCallableForCallee() {
1094   // Direct call.
1095   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1096     return calleeAttr;
1097   // Indirect call, callee Value is the first operand.
1098   return getOperand(0);
1099 }
1100 
1101 void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1102   // Direct call.
1103   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1104     auto symRef = cast<SymbolRefAttr>(callee);
1105     return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1106   }
1107   // Indirect call, callee Value is the first operand.
1108   return setOperand(0, cast<Value>(callee));
1109 }
1110 
1111 Operation::operand_range CallOp::getArgOperands() {
1112   return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1113 }
1114 
1115 MutableOperandRange CallOp::getArgOperandsMutable() {
1116   return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1117                              getCalleeOperands().size());
1118 }
1119 
1120 /// Verify that an inlinable callsite of a debug-info-bearing function in a
1121 /// debug-info-bearing function has a debug location attached to it. This
1122 /// mirrors an LLVM IR verifier.
1123 static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1124   if (callee.isExternal())
1125     return success();
1126   auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1127   if (!parentFunc)
1128     return success();
1129 
1130   auto hasSubprogram = [](Operation *op) {
1131     return op->getLoc()
1132                ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1133            nullptr;
1134   };
1135   if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1136     return success();
1137   bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1138   if (!containsLoc)
1139     return callOp.emitError()
1140            << "inlinable function call in a function with a DISubprogram "
1141               "location must have a debug location";
1142   return success();
1143 }
1144 
1145 /// Verify that the parameter and return types of the variadic callee type match
1146 /// the `callOp` argument and result types.
1147 template <typename OpTy>
1148 LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1149   std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1150   if (!varCalleeType)
1151     return success();
1152 
1153   // Verify the variadic callee type is a variadic function type.
1154   if (!varCalleeType->isVarArg())
1155     return callOp.emitOpError(
1156         "expected var_callee_type to be a variadic function type");
1157 
1158   // Verify the variadic callee type has at most as many parameters as the call
1159   // has argument operands.
1160   if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1161     return callOp.emitOpError("expected var_callee_type to have at most ")
1162            << callOp.getArgOperands().size() << " parameters";
1163 
1164   // Verify the variadic callee type matches the call argument types.
1165   for (auto [paramType, operand] :
1166        llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1167     if (paramType != operand.getType())
1168       return callOp.emitOpError()
1169              << "var_callee_type parameter type mismatch: " << paramType
1170              << " != " << operand.getType();
1171 
1172   // Verify the variadic callee type matches the call result type.
1173   if (!callOp.getNumResults()) {
1174     if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1175       return callOp.emitOpError("expected var_callee_type to return void");
1176   } else {
1177     if (callOp.getResult().getType() != varCalleeType->getReturnType())
1178       return callOp.emitOpError("var_callee_type return type mismatch: ")
1179              << varCalleeType->getReturnType()
1180              << " != " << callOp.getResult().getType();
1181   }
1182   return success();
1183 }
1184 
1185 template <typename OpType>
1186 static LogicalResult verifyOperandBundles(OpType &op) {
1187   OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1188   std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1189 
1190   auto isStringAttr = [](Attribute tagAttr) {
1191     return isa<StringAttr>(tagAttr);
1192   };
1193   if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1194     return op.emitError("operand bundle tag must be a StringAttr");
1195 
1196   size_t numOpBundles = opBundleOperands.size();
1197   size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1198   if (numOpBundles != numOpBundleTags)
1199     return op.emitError("expected ")
1200            << numOpBundles << " operand bundle tags, but actually got "
1201            << numOpBundleTags;
1202 
1203   return success();
1204 }
1205 
1206 LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1207 
1208 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1209   if (failed(verifyCallOpVarCalleeType(*this)))
1210     return failure();
1211 
1212   // Type for the callee, we'll get it differently depending if it is a direct
1213   // or indirect call.
1214   Type fnType;
1215 
1216   bool isIndirect = false;
1217 
1218   // If this is an indirect call, the callee attribute is missing.
1219   FlatSymbolRefAttr calleeName = getCalleeAttr();
1220   if (!calleeName) {
1221     isIndirect = true;
1222     if (!getNumOperands())
1223       return emitOpError(
1224           "must have either a `callee` attribute or at least an operand");
1225     auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1226     if (!ptrType)
1227       return emitOpError("indirect call expects a pointer as callee: ")
1228              << getOperand(0).getType();
1229 
1230     return success();
1231   } else {
1232     Operation *callee =
1233         symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1234     if (!callee)
1235       return emitOpError()
1236              << "'" << calleeName.getValue()
1237              << "' does not reference a symbol in the current scope";
1238     auto fn = dyn_cast<LLVMFuncOp>(callee);
1239     if (!fn)
1240       return emitOpError() << "'" << calleeName.getValue()
1241                            << "' does not reference a valid LLVM function";
1242 
1243     if (failed(verifyCallOpDebugInfo(*this, fn)))
1244       return failure();
1245     fnType = fn.getFunctionType();
1246   }
1247 
1248   LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1249   if (!funcType)
1250     return emitOpError("callee does not have a functional type: ") << fnType;
1251 
1252   if (funcType.isVarArg() && !getVarCalleeType())
1253     return emitOpError() << "missing var_callee_type attribute for vararg call";
1254 
1255   // Verify that the operand and result types match the callee.
1256 
1257   if (!funcType.isVarArg() &&
1258       funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1259     return emitOpError() << "incorrect number of operands ("
1260                          << (getCalleeOperands().size() - isIndirect)
1261                          << ") for callee (expecting: "
1262                          << funcType.getNumParams() << ")";
1263 
1264   if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1265     return emitOpError() << "incorrect number of operands ("
1266                          << (getCalleeOperands().size() - isIndirect)
1267                          << ") for varargs callee (expecting at least: "
1268                          << funcType.getNumParams() << ")";
1269 
1270   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1271     if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1272       return emitOpError() << "operand type mismatch for operand " << i << ": "
1273                            << getOperand(i + isIndirect).getType()
1274                            << " != " << funcType.getParamType(i);
1275 
1276   if (getNumResults() == 0 &&
1277       !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1278     return emitOpError() << "expected function call to produce a value";
1279 
1280   if (getNumResults() != 0 &&
1281       llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1282     return emitOpError()
1283            << "calling function with void result must not produce values";
1284 
1285   if (getNumResults() > 1)
1286     return emitOpError()
1287            << "expected LLVM function call to produce 0 or 1 result";
1288 
1289   if (getNumResults() && getResult().getType() != funcType.getReturnType())
1290     return emitOpError() << "result type mismatch: " << getResult().getType()
1291                          << " != " << funcType.getReturnType();
1292 
1293   return success();
1294 }
1295 
1296 void CallOp::print(OpAsmPrinter &p) {
1297   auto callee = getCallee();
1298   bool isDirect = callee.has_value();
1299 
1300   p << ' ';
1301 
1302   // Print calling convention.
1303   if (getCConv() != LLVM::CConv::C)
1304     p << stringifyCConv(getCConv()) << ' ';
1305 
1306   if (getTailCallKind() != LLVM::TailCallKind::None)
1307     p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1308 
1309   // Print the direct callee if present as a function attribute, or an indirect
1310   // callee (first operand) otherwise.
1311   if (isDirect)
1312     p.printSymbolName(callee.value());
1313   else
1314     p << getOperand(0);
1315 
1316   auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
1317   p << '(' << args << ')';
1318 
1319   // Print the variadic callee type if the call is variadic.
1320   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1321     p << " vararg(" << *varCalleeType << ")";
1322 
1323   if (!getOpBundleOperands().empty()) {
1324     p << " ";
1325     printOpBundles(p, *this, getOpBundleOperands(),
1326                    getOpBundleOperands().getTypes(), getOpBundleTags());
1327   }
1328 
1329   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1330                           {getCalleeAttrName(), getTailCallKindAttrName(),
1331                            getVarCalleeTypeAttrName(), getCConvAttrName(),
1332                            getOperandSegmentSizesAttrName(),
1333                            getOpBundleSizesAttrName(),
1334                            getOpBundleTagsAttrName()});
1335 
1336   p << " : ";
1337   if (!isDirect)
1338     p << getOperand(0).getType() << ", ";
1339 
1340   // Reconstruct the function MLIR function type from operand and result types.
1341   p.printFunctionalType(args.getTypes(), getResultTypes());
1342 }
1343 
1344 /// Parses the type of a call operation and resolves the operands if the parsing
1345 /// succeeds. Returns failure otherwise.
1346 static ParseResult parseCallTypeAndResolveOperands(
1347     OpAsmParser &parser, OperationState &result, bool isDirect,
1348     ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1349   SMLoc trailingTypesLoc = parser.getCurrentLocation();
1350   SmallVector<Type> types;
1351   if (parser.parseColonTypeList(types))
1352     return failure();
1353 
1354   if (isDirect && types.size() != 1)
1355     return parser.emitError(trailingTypesLoc,
1356                             "expected direct call to have 1 trailing type");
1357   if (!isDirect && types.size() != 2)
1358     return parser.emitError(trailingTypesLoc,
1359                             "expected indirect call to have 2 trailing types");
1360 
1361   auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
1362   if (!funcType)
1363     return parser.emitError(trailingTypesLoc,
1364                             "expected trailing function type");
1365   if (funcType.getNumResults() > 1)
1366     return parser.emitError(trailingTypesLoc,
1367                             "expected function with 0 or 1 result");
1368   if (funcType.getNumResults() == 1 &&
1369       llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
1370     return parser.emitError(trailingTypesLoc,
1371                             "expected a non-void result type");
1372 
1373   // The head element of the types list matches the callee type for
1374   // indirect calls, while the types list is emtpy for direct calls.
1375   // Append the function input types to resolve the call operation
1376   // operands.
1377   llvm::append_range(types, funcType.getInputs());
1378   if (parser.resolveOperands(operands, types, parser.getNameLoc(),
1379                              result.operands))
1380     return failure();
1381   if (funcType.getNumResults() != 0)
1382     result.addTypes(funcType.getResults());
1383 
1384   return success();
1385 }
1386 
1387 /// Parses an optional function pointer operand before the call argument list
1388 /// for indirect calls, or stops parsing at the function identifier otherwise.
1389 static ParseResult parseOptionalCallFuncPtr(
1390     OpAsmParser &parser,
1391     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
1392   OpAsmParser::UnresolvedOperand funcPtrOperand;
1393   OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
1394   if (parseResult.has_value()) {
1395     if (failed(*parseResult))
1396       return *parseResult;
1397     operands.push_back(funcPtrOperand);
1398   }
1399   return success();
1400 }
1401 
1402 static ParseResult resolveOpBundleOperands(
1403     OpAsmParser &parser, SMLoc loc, OperationState &state,
1404     ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
1405     ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1406     StringAttr opBundleSizesAttrName) {
1407   unsigned opBundleIndex = 0;
1408   for (const auto &[operands, types] :
1409        llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
1410     if (operands.size() != types.size())
1411       return parser.emitError(loc, "expected ")
1412              << operands.size()
1413              << " types for operand bundle operands for operand bundle #"
1414              << opBundleIndex << ", but actually got " << types.size();
1415     if (parser.resolveOperands(operands, types, loc, state.operands))
1416       return failure();
1417   }
1418 
1419   SmallVector<int32_t> opBundleSizes;
1420   opBundleSizes.reserve(opBundleOperands.size());
1421   for (const auto &operands : opBundleOperands)
1422     opBundleSizes.push_back(operands.size());
1423 
1424   state.addAttribute(
1425       opBundleSizesAttrName,
1426       DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
1427 
1428   return success();
1429 }
1430 
1431 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1432 //                             `(` ssa-use-list `)`
1433 //                             ( `vararg(` var-callee-type `)` )?
1434 //                             ( `[` op-bundles-list `]` )?
1435 //                             attribute-dict? `:` (type `,`)? function-type
1436 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1437   SymbolRefAttr funcAttr;
1438   TypeAttr varCalleeType;
1439   SmallVector<OpAsmParser::UnresolvedOperand> operands;
1440   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1441   SmallVector<SmallVector<Type>> opBundleOperandTypes;
1442   ArrayAttr opBundleTags;
1443 
1444   // Default to C Calling Convention if no keyword is provided.
1445   result.addAttribute(
1446       getCConvAttrName(result.name),
1447       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1448                                               parser, result, LLVM::CConv::C)));
1449 
1450   result.addAttribute(
1451       getTailCallKindAttrName(result.name),
1452       TailCallKindAttr::get(parser.getContext(),
1453                             parseOptionalLLVMKeyword<TailCallKind>(
1454                                 parser, result, LLVM::TailCallKind::None)));
1455 
1456   // Parse a function pointer for indirect calls.
1457   if (parseOptionalCallFuncPtr(parser, operands))
1458     return failure();
1459   bool isDirect = operands.empty();
1460 
1461   // Parse a function identifier for direct calls.
1462   if (isDirect)
1463     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1464       return failure();
1465 
1466   // Parse the function arguments.
1467   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1468     return failure();
1469 
1470   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1471   if (isVarArg) {
1472     StringAttr varCalleeTypeAttrName =
1473         CallOp::getVarCalleeTypeAttrName(result.name);
1474     if (parser.parseLParen().failed() ||
1475         parser
1476             .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1477                             result.attributes)
1478             .failed() ||
1479         parser.parseRParen().failed())
1480       return failure();
1481   }
1482 
1483   SMLoc opBundlesLoc = parser.getCurrentLocation();
1484   if (std::optional<ParseResult> result = parseOpBundles(
1485           parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1486       result && failed(*result))
1487     return failure();
1488   if (opBundleTags && !opBundleTags.empty())
1489     result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
1490                         opBundleTags);
1491 
1492   if (parser.parseOptionalAttrDict(result.attributes))
1493     return failure();
1494 
1495   // Parse the trailing type list and resolve the operands.
1496   if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1497     return failure();
1498   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1499                               opBundleOperandTypes,
1500                               getOpBundleSizesAttrName(result.name)))
1501     return failure();
1502 
1503   int32_t numOpBundleOperands = 0;
1504   for (const auto &operands : opBundleOperands)
1505     numOpBundleOperands += operands.size();
1506 
1507   result.addAttribute(
1508       CallOp::getOperandSegmentSizeAttr(),
1509       parser.getBuilder().getDenseI32ArrayAttr(
1510           {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1511   return success();
1512 }
1513 
1514 LLVMFunctionType CallOp::getCalleeFunctionType() {
1515   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1516     return *varCalleeType;
1517   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1518 }
1519 
1520 ///===---------------------------------------------------------------------===//
1521 /// LLVM::InvokeOp
1522 ///===---------------------------------------------------------------------===//
1523 
1524 void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1525                      ValueRange ops, Block *normal, ValueRange normalOps,
1526                      Block *unwind, ValueRange unwindOps) {
1527   auto calleeType = func.getFunctionType();
1528   build(builder, state, getCallOpResultTypes(calleeType),
1529         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1530         normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind);
1531 }
1532 
1533 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1534                      FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1535                      ValueRange normalOps, Block *unwind,
1536                      ValueRange unwindOps) {
1537   build(builder, state, tys,
1538         /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
1539         nullptr, {}, {}, normal, unwind);
1540 }
1541 
1542 void InvokeOp::build(OpBuilder &builder, OperationState &state,
1543                      LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1544                      ValueRange ops, Block *normal, ValueRange normalOps,
1545                      Block *unwind, ValueRange unwindOps) {
1546   build(builder, state, getCallOpResultTypes(calleeType),
1547         getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
1548         nullptr, nullptr, {}, {}, normal, unwind);
1549 }
1550 
1551 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1552   assert(index < getNumSuccessors() && "invalid successor index");
1553   return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1554                                       : getUnwindDestOperandsMutable());
1555 }
1556 
1557 CallInterfaceCallable InvokeOp::getCallableForCallee() {
1558   // Direct call.
1559   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1560     return calleeAttr;
1561   // Indirect call, callee Value is the first operand.
1562   return getOperand(0);
1563 }
1564 
1565 void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1566   // Direct call.
1567   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1568     auto symRef = cast<SymbolRefAttr>(callee);
1569     return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1570   }
1571   // Indirect call, callee Value is the first operand.
1572   return setOperand(0, cast<Value>(callee));
1573 }
1574 
1575 Operation::operand_range InvokeOp::getArgOperands() {
1576   return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1577 }
1578 
1579 MutableOperandRange InvokeOp::getArgOperandsMutable() {
1580   return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1581                              getCalleeOperands().size());
1582 }
1583 
1584 LogicalResult InvokeOp::verify() {
1585   if (failed(verifyCallOpVarCalleeType(*this)))
1586     return failure();
1587 
1588   Block *unwindDest = getUnwindDest();
1589   if (unwindDest->empty())
1590     return emitError("must have at least one operation in unwind destination");
1591 
1592   // In unwind destination, first operation must be LandingpadOp
1593   if (!isa<LandingpadOp>(unwindDest->front()))
1594     return emitError("first operation in unwind destination should be a "
1595                      "llvm.landingpad operation");
1596 
1597   if (failed(verifyOperandBundles(*this)))
1598     return failure();
1599 
1600   return success();
1601 }
1602 
1603 void InvokeOp::print(OpAsmPrinter &p) {
1604   auto callee = getCallee();
1605   bool isDirect = callee.has_value();
1606 
1607   p << ' ';
1608 
1609   // Print calling convention.
1610   if (getCConv() != LLVM::CConv::C)
1611     p << stringifyCConv(getCConv()) << ' ';
1612 
1613   // Either function name or pointer
1614   if (isDirect)
1615     p.printSymbolName(callee.value());
1616   else
1617     p << getOperand(0);
1618 
1619   p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
1620   p << " to ";
1621   p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1622   p << " unwind ";
1623   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1624 
1625   // Print the variadic callee type if the invoke is variadic.
1626   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1627     p << " vararg(" << *varCalleeType << ")";
1628 
1629   if (!getOpBundleOperands().empty()) {
1630     p << " ";
1631     printOpBundles(p, *this, getOpBundleOperands(),
1632                    getOpBundleOperands().getTypes(), getOpBundleTags());
1633   }
1634 
1635   p.printOptionalAttrDict((*this)->getAttrs(),
1636                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
1637                            getCConvAttrName(), getVarCalleeTypeAttrName(),
1638                            getOpBundleSizesAttrName(),
1639                            getOpBundleTagsAttrName()});
1640 
1641   p << " : ";
1642   if (!isDirect)
1643     p << getOperand(0).getType() << ", ";
1644   p.printFunctionalType(
1645       llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
1646       getResultTypes());
1647 }
1648 
1649 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1650 //                  `(` ssa-use-list `)`
1651 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
1652 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1653 //                  ( `vararg(` var-callee-type `)` )?
1654 //                  ( `[` op-bundles-list `]` )?
1655 //                  attribute-dict? `:` (type `,`)? function-type
1656 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1657   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1658   SymbolRefAttr funcAttr;
1659   TypeAttr varCalleeType;
1660   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1661   SmallVector<SmallVector<Type>> opBundleOperandTypes;
1662   ArrayAttr opBundleTags;
1663   Block *normalDest, *unwindDest;
1664   SmallVector<Value, 4> normalOperands, unwindOperands;
1665   Builder &builder = parser.getBuilder();
1666 
1667   // Default to C Calling Convention if no keyword is provided.
1668   result.addAttribute(
1669       getCConvAttrName(result.name),
1670       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1671                                               parser, result, LLVM::CConv::C)));
1672 
1673   // Parse a function pointer for indirect calls.
1674   if (parseOptionalCallFuncPtr(parser, operands))
1675     return failure();
1676   bool isDirect = operands.empty();
1677 
1678   // Parse a function identifier for direct calls.
1679   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1680     return failure();
1681 
1682   // Parse the function arguments.
1683   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1684       parser.parseKeyword("to") ||
1685       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1686       parser.parseKeyword("unwind") ||
1687       parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1688     return failure();
1689 
1690   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1691   if (isVarArg) {
1692     StringAttr varCalleeTypeAttrName =
1693         InvokeOp::getVarCalleeTypeAttrName(result.name);
1694     if (parser.parseLParen().failed() ||
1695         parser
1696             .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1697                             result.attributes)
1698             .failed() ||
1699         parser.parseRParen().failed())
1700       return failure();
1701   }
1702 
1703   SMLoc opBundlesLoc = parser.getCurrentLocation();
1704   if (std::optional<ParseResult> result = parseOpBundles(
1705           parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1706       result && failed(*result))
1707     return failure();
1708   if (opBundleTags && !opBundleTags.empty())
1709     result.addAttribute(
1710         InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
1711         opBundleTags);
1712 
1713   if (parser.parseOptionalAttrDict(result.attributes))
1714     return failure();
1715 
1716   // Parse the trailing type list and resolve the function operands.
1717   if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1718     return failure();
1719   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1720                               opBundleOperandTypes,
1721                               getOpBundleSizesAttrName(result.name)))
1722     return failure();
1723 
1724   result.addSuccessors({normalDest, unwindDest});
1725   result.addOperands(normalOperands);
1726   result.addOperands(unwindOperands);
1727 
1728   int32_t numOpBundleOperands = 0;
1729   for (const auto &operands : opBundleOperands)
1730     numOpBundleOperands += operands.size();
1731 
1732   result.addAttribute(
1733       InvokeOp::getOperandSegmentSizeAttr(),
1734       builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
1735                                     static_cast<int32_t>(normalOperands.size()),
1736                                     static_cast<int32_t>(unwindOperands.size()),
1737                                     numOpBundleOperands}));
1738   return success();
1739 }
1740 
1741 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1742   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1743     return *varCalleeType;
1744   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1745 }
1746 
1747 ///===----------------------------------------------------------------------===//
1748 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1749 ///===----------------------------------------------------------------------===//
1750 
1751 LogicalResult LandingpadOp::verify() {
1752   Value value;
1753   if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1754     if (!func.getPersonality())
1755       return emitError(
1756           "llvm.landingpad needs to be in a function with a personality");
1757   }
1758 
1759   // Consistency of llvm.landingpad result types is checked in
1760   // LLVMFuncOp::verify().
1761 
1762   if (!getCleanup() && getOperands().empty())
1763     return emitError("landingpad instruction expects at least one clause or "
1764                      "cleanup attribute");
1765 
1766   for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1767     value = getOperand(idx);
1768     bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1769     if (isFilter) {
1770       // FIXME: Verify filter clauses when arrays are appropriately handled
1771     } else {
1772       // catch - global addresses only.
1773       // Bitcast ops should have global addresses as their args.
1774       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1775         if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1776           continue;
1777         return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1778                << "global addresses expected as operand to "
1779                   "bitcast used in clauses for landingpad";
1780       }
1781       // ZeroOp and AddressOfOp allowed
1782       if (value.getDefiningOp<ZeroOp>())
1783         continue;
1784       if (value.getDefiningOp<AddressOfOp>())
1785         continue;
1786       return emitError("clause #")
1787              << idx << " is not a known constant - null, addressof, bitcast";
1788     }
1789   }
1790   return success();
1791 }
1792 
1793 void LandingpadOp::print(OpAsmPrinter &p) {
1794   p << (getCleanup() ? " cleanup " : " ");
1795 
1796   // Clauses
1797   for (auto value : getOperands()) {
1798     // Similar to llvm - if clause is an array type then it is filter
1799     // clause else catch clause
1800     bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1801     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1802       << value.getType() << ") ";
1803   }
1804 
1805   p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1806 
1807   p << ": " << getType();
1808 }
1809 
1810 // <operation> ::= `llvm.landingpad` `cleanup`?
1811 //                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1812 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1813   // Check for cleanup
1814   if (succeeded(parser.parseOptionalKeyword("cleanup")))
1815     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1816 
1817   // Parse clauses with types
1818   while (succeeded(parser.parseOptionalLParen()) &&
1819          (succeeded(parser.parseOptionalKeyword("filter")) ||
1820           succeeded(parser.parseOptionalKeyword("catch")))) {
1821     OpAsmParser::UnresolvedOperand operand;
1822     Type ty;
1823     if (parser.parseOperand(operand) || parser.parseColon() ||
1824         parser.parseType(ty) ||
1825         parser.resolveOperand(operand, ty, result.operands) ||
1826         parser.parseRParen())
1827       return failure();
1828   }
1829 
1830   Type type;
1831   if (parser.parseColon() || parser.parseType(type))
1832     return failure();
1833 
1834   result.addTypes(type);
1835   return success();
1836 }
1837 
1838 //===----------------------------------------------------------------------===//
1839 // ExtractValueOp
1840 //===----------------------------------------------------------------------===//
1841 
1842 /// Extract the type at `position` in the LLVM IR aggregate type
1843 /// `containerType`. Each element of `position` is an index into a nested
1844 /// aggregate type. Return the resulting type or emit an error.
1845 static Type getInsertExtractValueElementType(
1846     function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1847     ArrayRef<int64_t> position) {
1848   Type llvmType = containerType;
1849   if (!isCompatibleType(containerType)) {
1850     emitError("expected LLVM IR Dialect type, got ") << containerType;
1851     return {};
1852   }
1853 
1854   // Infer the element type from the structure type: iteratively step inside the
1855   // type by taking the element type, indexed by the position attribute for
1856   // structures.  Check the position index before accessing, it is supposed to
1857   // be in bounds.
1858   for (int64_t idx : position) {
1859     if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1860       if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1861         emitError("position out of bounds: ") << idx;
1862         return {};
1863       }
1864       llvmType = arrayType.getElementType();
1865     } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1866       if (idx < 0 ||
1867           static_cast<unsigned>(idx) >= structType.getBody().size()) {
1868         emitError("position out of bounds: ") << idx;
1869         return {};
1870       }
1871       llvmType = structType.getBody()[idx];
1872     } else {
1873       emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1874       return {};
1875     }
1876   }
1877   return llvmType;
1878 }
1879 
1880 /// Extract the type at `position` in the wrapped LLVM IR aggregate type
1881 /// `containerType`.
1882 static Type getInsertExtractValueElementType(Type llvmType,
1883                                              ArrayRef<int64_t> position) {
1884   for (int64_t idx : position) {
1885     if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1886       llvmType = structType.getBody()[idx];
1887     else
1888       llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1889   }
1890   return llvmType;
1891 }
1892 
1893 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1894   auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
1895   OpFoldResult result = {};
1896   while (insertValueOp) {
1897     if (getPosition() == insertValueOp.getPosition())
1898       return insertValueOp.getValue();
1899     unsigned min =
1900         std::min(getPosition().size(), insertValueOp.getPosition().size());
1901     // If one is fully prefix of the other, stop propagating back as it will
1902     // miss dependencies. For instance, %3 should not fold to %f0 in the
1903     // following example:
1904     // ```
1905     //   %1 = llvm.insertvalue %f0, %0[0, 0] :
1906     //     !llvm.array<4 x !llvm.array<4 x f32>>
1907     //   %2 = llvm.insertvalue %arr, %1[0] :
1908     //     !llvm.array<4 x !llvm.array<4 x f32>>
1909     //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
1910     // ```
1911     if (getPosition().take_front(min) ==
1912         insertValueOp.getPosition().take_front(min))
1913       return result;
1914 
1915     // If neither a prefix, nor the exact position, we can extract out of the
1916     // value being inserted into. Moreover, we can try again if that operand
1917     // is itself an insertvalue expression.
1918     getContainerMutable().assign(insertValueOp.getContainer());
1919     result = getResult();
1920     insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
1921   }
1922   return result;
1923 }
1924 
1925 LogicalResult ExtractValueOp::verify() {
1926   auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1927   Type valueType = getInsertExtractValueElementType(
1928       emitError, getContainer().getType(), getPosition());
1929   if (!valueType)
1930     return failure();
1931 
1932   if (getRes().getType() != valueType)
1933     return emitOpError() << "Type mismatch: extracting from "
1934                          << getContainer().getType() << " should produce "
1935                          << valueType << " but this op returns "
1936                          << getRes().getType();
1937   return success();
1938 }
1939 
1940 void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1941                            Value container, ArrayRef<int64_t> position) {
1942   build(builder, state,
1943         getInsertExtractValueElementType(container.getType(), position),
1944         container, builder.getAttr<DenseI64ArrayAttr>(position));
1945 }
1946 
1947 //===----------------------------------------------------------------------===//
1948 // InsertValueOp
1949 //===----------------------------------------------------------------------===//
1950 
1951 /// Infer the value type from the container type and position.
1952 static ParseResult
1953 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
1954                                    Type containerType,
1955                                    DenseI64ArrayAttr position) {
1956   valueType = getInsertExtractValueElementType(
1957       [&](StringRef msg) {
1958         return parser.emitError(parser.getCurrentLocation(), msg);
1959       },
1960       containerType, position.asArrayRef());
1961   return success(!!valueType);
1962 }
1963 
1964 /// Nothing to print for an inferred type.
1965 static void printInsertExtractValueElementType(AsmPrinter &printer,
1966                                                Operation *op, Type valueType,
1967                                                Type containerType,
1968                                                DenseI64ArrayAttr position) {}
1969 
1970 LogicalResult InsertValueOp::verify() {
1971   auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1972   Type valueType = getInsertExtractValueElementType(
1973       emitError, getContainer().getType(), getPosition());
1974   if (!valueType)
1975     return failure();
1976 
1977   if (getValue().getType() != valueType)
1978     return emitOpError() << "Type mismatch: cannot insert "
1979                          << getValue().getType() << " into "
1980                          << getContainer().getType();
1981 
1982   return success();
1983 }
1984 
1985 //===----------------------------------------------------------------------===//
1986 // ReturnOp
1987 //===----------------------------------------------------------------------===//
1988 
1989 LogicalResult ReturnOp::verify() {
1990   auto parent = (*this)->getParentOfType<LLVMFuncOp>();
1991   if (!parent)
1992     return success();
1993 
1994   Type expectedType = parent.getFunctionType().getReturnType();
1995   if (llvm::isa<LLVMVoidType>(expectedType)) {
1996     if (!getArg())
1997       return success();
1998     InFlightDiagnostic diag = emitOpError("expected no operands");
1999     diag.attachNote(parent->getLoc()) << "when returning from function";
2000     return diag;
2001   }
2002   if (!getArg()) {
2003     if (llvm::isa<LLVMVoidType>(expectedType))
2004       return success();
2005     InFlightDiagnostic diag = emitOpError("expected 1 operand");
2006     diag.attachNote(parent->getLoc()) << "when returning from function";
2007     return diag;
2008   }
2009   if (expectedType != getArg().getType()) {
2010     InFlightDiagnostic diag = emitOpError("mismatching result types");
2011     diag.attachNote(parent->getLoc()) << "when returning from function";
2012     return diag;
2013   }
2014   return success();
2015 }
2016 
2017 //===----------------------------------------------------------------------===//
2018 // LLVM::AddressOfOp.
2019 //===----------------------------------------------------------------------===//
2020 
2021 static Operation *parentLLVMModule(Operation *op) {
2022   Operation *module = op->getParentOp();
2023   while (module && !satisfiesLLVMModule(module))
2024     module = module->getParentOp();
2025   assert(module && "unexpected operation outside of a module");
2026   return module;
2027 }
2028 
2029 GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2030   return dyn_cast_or_null<GlobalOp>(
2031       symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2032 }
2033 
2034 LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2035   return dyn_cast_or_null<LLVMFuncOp>(
2036       symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2037 }
2038 
2039 LogicalResult
2040 AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2041   Operation *symbol =
2042       symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2043 
2044   auto global = dyn_cast_or_null<GlobalOp>(symbol);
2045   auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2046 
2047   if (!global && !function)
2048     return emitOpError(
2049         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
2050 
2051   LLVMPointerType type = getType();
2052   if (global && global.getAddrSpace() != type.getAddressSpace())
2053     return emitOpError("pointer address space must match address space of the "
2054                        "referenced global");
2055 
2056   return success();
2057 }
2058 
2059 // AddressOfOp constant-folds to the global symbol name.
2060 OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2061   return getGlobalNameAttr();
2062 }
2063 
2064 //===----------------------------------------------------------------------===//
2065 // Verifier for LLVM::ComdatOp.
2066 //===----------------------------------------------------------------------===//
2067 
2068 void ComdatOp::build(OpBuilder &builder, OperationState &result,
2069                      StringRef symName) {
2070   result.addAttribute(getSymNameAttrName(result.name),
2071                       builder.getStringAttr(symName));
2072   Region *body = result.addRegion();
2073   body->emplaceBlock();
2074 }
2075 
2076 LogicalResult ComdatOp::verifyRegions() {
2077   Region &body = getBody();
2078   for (Operation &op : body.getOps())
2079     if (!isa<ComdatSelectorOp>(op))
2080       return op.emitError(
2081           "only comdat selector symbols can appear in a comdat region");
2082 
2083   return success();
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // Builder, printer and verifier for LLVM::GlobalOp.
2088 //===----------------------------------------------------------------------===//
2089 
2090 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2091                      bool isConstant, Linkage linkage, StringRef name,
2092                      Attribute value, uint64_t alignment, unsigned addrSpace,
2093                      bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2094                      ArrayRef<NamedAttribute> attrs,
2095                      ArrayRef<Attribute> dbgExprs) {
2096   result.addAttribute(getSymNameAttrName(result.name),
2097                       builder.getStringAttr(name));
2098   result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
2099   if (isConstant)
2100     result.addAttribute(getConstantAttrName(result.name),
2101                         builder.getUnitAttr());
2102   if (value)
2103     result.addAttribute(getValueAttrName(result.name), value);
2104   if (dsoLocal)
2105     result.addAttribute(getDsoLocalAttrName(result.name),
2106                         builder.getUnitAttr());
2107   if (threadLocal)
2108     result.addAttribute(getThreadLocal_AttrName(result.name),
2109                         builder.getUnitAttr());
2110   if (comdat)
2111     result.addAttribute(getComdatAttrName(result.name), comdat);
2112 
2113   // Only add an alignment attribute if the "alignment" input
2114   // is different from 0. The value must also be a power of two, but
2115   // this is tested in GlobalOp::verify, not here.
2116   if (alignment != 0)
2117     result.addAttribute(getAlignmentAttrName(result.name),
2118                         builder.getI64IntegerAttr(alignment));
2119 
2120   result.addAttribute(getLinkageAttrName(result.name),
2121                       LinkageAttr::get(builder.getContext(), linkage));
2122   if (addrSpace != 0)
2123     result.addAttribute(getAddrSpaceAttrName(result.name),
2124                         builder.getI32IntegerAttr(addrSpace));
2125   result.attributes.append(attrs.begin(), attrs.end());
2126 
2127   if (!dbgExprs.empty())
2128     result.addAttribute(getDbgExprsAttrName(result.name),
2129                         ArrayAttr::get(builder.getContext(), dbgExprs));
2130 
2131   result.addRegion();
2132 }
2133 
2134 void GlobalOp::print(OpAsmPrinter &p) {
2135   p << ' ' << stringifyLinkage(getLinkage()) << ' ';
2136   StringRef visibility = stringifyVisibility(getVisibility_());
2137   if (!visibility.empty())
2138     p << visibility << ' ';
2139   if (getThreadLocal_())
2140     p << "thread_local ";
2141   if (auto unnamedAddr = getUnnamedAddr()) {
2142     StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2143     if (!str.empty())
2144       p << str << ' ';
2145   }
2146   if (getConstant())
2147     p << "constant ";
2148   p.printSymbolName(getSymName());
2149   p << '(';
2150   if (auto value = getValueOrNull())
2151     p.printAttribute(value);
2152   p << ')';
2153   if (auto comdat = getComdat())
2154     p << " comdat(" << *comdat << ')';
2155 
2156   // Note that the alignment attribute is printed using the
2157   // default syntax here, even though it is an inherent attribute
2158   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
2159   p.printOptionalAttrDict((*this)->getAttrs(),
2160                           {SymbolTable::getSymbolAttrName(),
2161                            getGlobalTypeAttrName(), getConstantAttrName(),
2162                            getValueAttrName(), getLinkageAttrName(),
2163                            getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2164                            getVisibility_AttrName(), getComdatAttrName(),
2165                            getUnnamedAddrAttrName()});
2166 
2167   // Print the trailing type unless it's a string global.
2168   if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
2169     return;
2170   p << " : " << getType();
2171 
2172   Region &initializer = getInitializerRegion();
2173   if (!initializer.empty()) {
2174     p << ' ';
2175     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
2176   }
2177 }
2178 
2179 static LogicalResult verifyComdat(Operation *op,
2180                                   std::optional<SymbolRefAttr> attr) {
2181   if (!attr)
2182     return success();
2183 
2184   auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2185   if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2186     return op->emitError() << "expected comdat symbol";
2187 
2188   return success();
2189 }
2190 
2191 // operation ::= `llvm.mlir.global` linkage? visibility?
2192 //               (`unnamed_addr` | `local_unnamed_addr`)?
2193 //               `thread_local`? `constant`? `@` identifier
2194 //               `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2195 //               attribute-list? (`:` type)? region?
2196 //
2197 // The type can be omitted for string attributes, in which case it will be
2198 // inferred from the value of the string as [strlen(value) x i8].
2199 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2200   MLIRContext *ctx = parser.getContext();
2201   // Parse optional linkage, default to External.
2202   result.addAttribute(getLinkageAttrName(result.name),
2203                       LLVM::LinkageAttr::get(
2204                           ctx, parseOptionalLLVMKeyword<Linkage>(
2205                                    parser, result, LLVM::Linkage::External)));
2206 
2207   // Parse optional visibility, default to Default.
2208   result.addAttribute(getVisibility_AttrName(result.name),
2209                       parser.getBuilder().getI64IntegerAttr(
2210                           parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2211                               parser, result, LLVM::Visibility::Default)));
2212 
2213   // Parse optional UnnamedAddr, default to None.
2214   result.addAttribute(getUnnamedAddrAttrName(result.name),
2215                       parser.getBuilder().getI64IntegerAttr(
2216                           parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2217                               parser, result, LLVM::UnnamedAddr::None)));
2218 
2219   if (succeeded(parser.parseOptionalKeyword("thread_local")))
2220     result.addAttribute(getThreadLocal_AttrName(result.name),
2221                         parser.getBuilder().getUnitAttr());
2222 
2223   if (succeeded(parser.parseOptionalKeyword("constant")))
2224     result.addAttribute(getConstantAttrName(result.name),
2225                         parser.getBuilder().getUnitAttr());
2226 
2227   StringAttr name;
2228   if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2229                              result.attributes) ||
2230       parser.parseLParen())
2231     return failure();
2232 
2233   Attribute value;
2234   if (parser.parseOptionalRParen()) {
2235     if (parser.parseAttribute(value, getValueAttrName(result.name),
2236                               result.attributes) ||
2237         parser.parseRParen())
2238       return failure();
2239   }
2240 
2241   if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2242     SymbolRefAttr comdat;
2243     if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2244         parser.parseRParen())
2245       return failure();
2246 
2247     result.addAttribute(getComdatAttrName(result.name), comdat);
2248   }
2249 
2250   SmallVector<Type, 1> types;
2251   if (parser.parseOptionalAttrDict(result.attributes) ||
2252       parser.parseOptionalColonTypeList(types))
2253     return failure();
2254 
2255   if (types.size() > 1)
2256     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2257 
2258   Region &initRegion = *result.addRegion();
2259   if (types.empty()) {
2260     if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2261       MLIRContext *context = parser.getContext();
2262       auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2263                                                 strAttr.getValue().size());
2264       types.push_back(arrayType);
2265     } else {
2266       return parser.emitError(parser.getNameLoc(),
2267                               "type can only be omitted for string globals");
2268     }
2269   } else {
2270     OptionalParseResult parseResult =
2271         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2272                                    /*argTypes=*/{});
2273     if (parseResult.has_value() && failed(*parseResult))
2274       return failure();
2275   }
2276 
2277   result.addAttribute(getGlobalTypeAttrName(result.name),
2278                       TypeAttr::get(types[0]));
2279   return success();
2280 }
2281 
2282 static bool isZeroAttribute(Attribute value) {
2283   if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2284     return intValue.getValue().isZero();
2285   if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2286     return fpValue.getValue().isZero();
2287   if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
2288     return isZeroAttribute(splatValue.getSplatValue<Attribute>());
2289   if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2290     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2291   if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2292     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2293   return false;
2294 }
2295 
2296 LogicalResult GlobalOp::verify() {
2297   bool validType = isCompatibleOuterType(getType())
2298                        ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2299                                     LLVMMetadataType, LLVMLabelType>(getType())
2300                        : llvm::isa<PointerElementTypeInterface>(getType());
2301   if (!validType)
2302     return emitOpError(
2303         "expects type to be a valid element type for an LLVM global");
2304   if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2305     return emitOpError("must appear at the module level");
2306 
2307   if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2308     auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2309     IntegerType elementType =
2310         type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2311     if (!elementType || elementType.getWidth() != 8 ||
2312         type.getNumElements() != strAttr.getValue().size())
2313       return emitOpError(
2314           "requires an i8 array type of the length equal to that of the string "
2315           "attribute");
2316   }
2317 
2318   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2319     if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2320       return emitOpError()
2321              << "this target extension type cannot be used in a global";
2322 
2323     if (Attribute value = getValueOrNull())
2324       return emitOpError() << "global with target extension type can only be "
2325                               "initialized with zero-initializer";
2326   }
2327 
2328   if (getLinkage() == Linkage::Common) {
2329     if (Attribute value = getValueOrNull()) {
2330       if (!isZeroAttribute(value)) {
2331         return emitOpError()
2332                << "expected zero value for '"
2333                << stringifyLinkage(Linkage::Common) << "' linkage";
2334       }
2335     }
2336   }
2337 
2338   if (getLinkage() == Linkage::Appending) {
2339     if (!llvm::isa<LLVMArrayType>(getType())) {
2340       return emitOpError() << "expected array type for '"
2341                            << stringifyLinkage(Linkage::Appending)
2342                            << "' linkage";
2343     }
2344   }
2345 
2346   if (failed(verifyComdat(*this, getComdat())))
2347     return failure();
2348 
2349   std::optional<uint64_t> alignAttr = getAlignment();
2350   if (alignAttr.has_value()) {
2351     uint64_t value = alignAttr.value();
2352     if (!llvm::isPowerOf2_64(value))
2353       return emitError() << "alignment attribute is not a power of 2";
2354   }
2355 
2356   return success();
2357 }
2358 
2359 LogicalResult GlobalOp::verifyRegions() {
2360   if (Block *b = getInitializerBlock()) {
2361     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2362     if (ret.operand_type_begin() == ret.operand_type_end())
2363       return emitOpError("initializer region cannot return void");
2364     if (*ret.operand_type_begin() != getType())
2365       return emitOpError("initializer region type ")
2366              << *ret.operand_type_begin() << " does not match global type "
2367              << getType();
2368 
2369     for (Operation &op : *b) {
2370       auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2371       if (!iface || !iface.hasNoEffect())
2372         return op.emitError()
2373                << "ops with side effects not allowed in global initializers";
2374     }
2375 
2376     if (getValueOrNull())
2377       return emitOpError("cannot have both initializer value and region");
2378   }
2379 
2380   return success();
2381 }
2382 
2383 //===----------------------------------------------------------------------===//
2384 // LLVM::GlobalCtorsOp
2385 //===----------------------------------------------------------------------===//
2386 
2387 LogicalResult
2388 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2389   for (Attribute ctor : getCtors()) {
2390     if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2391                                    symbolTable)))
2392       return failure();
2393   }
2394   return success();
2395 }
2396 
2397 LogicalResult GlobalCtorsOp::verify() {
2398   if (getCtors().size() != getPriorities().size())
2399     return emitError(
2400         "mismatch between the number of ctors and the number of priorities");
2401   return success();
2402 }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // LLVM::GlobalDtorsOp
2406 //===----------------------------------------------------------------------===//
2407 
2408 LogicalResult
2409 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2410   for (Attribute dtor : getDtors()) {
2411     if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2412                                    symbolTable)))
2413       return failure();
2414   }
2415   return success();
2416 }
2417 
2418 LogicalResult GlobalDtorsOp::verify() {
2419   if (getDtors().size() != getPriorities().size())
2420     return emitError(
2421         "mismatch between the number of dtors and the number of priorities");
2422   return success();
2423 }
2424 
2425 //===----------------------------------------------------------------------===//
2426 // ShuffleVectorOp
2427 //===----------------------------------------------------------------------===//
2428 
2429 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2430                             Value v2, DenseI32ArrayAttr mask,
2431                             ArrayRef<NamedAttribute> attrs) {
2432   auto containerType = v1.getType();
2433   auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2434                                    mask.size(),
2435                                    LLVM::isScalableVectorType(containerType));
2436   build(builder, state, vType, v1, v2, mask);
2437   state.addAttributes(attrs);
2438 }
2439 
2440 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2441                             Value v2, ArrayRef<int32_t> mask) {
2442   build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2443 }
2444 
2445 /// Build the result type of a shuffle vector operation.
2446 static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2447                                     Type &resType, DenseI32ArrayAttr mask) {
2448   if (!LLVM::isCompatibleVectorType(v1Type))
2449     return parser.emitError(parser.getCurrentLocation(),
2450                             "expected an LLVM compatible vector type");
2451   resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
2452                                 LLVM::isScalableVectorType(v1Type));
2453   return success();
2454 }
2455 
2456 /// Nothing to do when the result type is inferred.
2457 static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2458                              Type resType, DenseI32ArrayAttr mask) {}
2459 
2460 LogicalResult ShuffleVectorOp::verify() {
2461   if (LLVM::isScalableVectorType(getV1().getType()) &&
2462       llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2463     return emitOpError("expected a splat operation for scalable vectors");
2464   return success();
2465 }
2466 
2467 //===----------------------------------------------------------------------===//
2468 // Implementations for LLVM::LLVMFuncOp.
2469 //===----------------------------------------------------------------------===//
2470 
2471 // Add the entry block to the function.
2472 Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
2473   assert(empty() && "function already has an entry block");
2474   OpBuilder::InsertionGuard g(builder);
2475   Block *entry = builder.createBlock(&getBody());
2476 
2477   // FIXME: Allow passing in proper locations for the entry arguments.
2478   LLVMFunctionType type = getFunctionType();
2479   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2480     entry->addArgument(type.getParamType(i), getLoc());
2481   return entry;
2482 }
2483 
2484 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2485                        StringRef name, Type type, LLVM::Linkage linkage,
2486                        bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
2487                        ArrayRef<NamedAttribute> attrs,
2488                        ArrayRef<DictionaryAttr> argAttrs,
2489                        std::optional<uint64_t> functionEntryCount) {
2490   result.addRegion();
2491   result.addAttribute(SymbolTable::getSymbolAttrName(),
2492                       builder.getStringAttr(name));
2493   result.addAttribute(getFunctionTypeAttrName(result.name),
2494                       TypeAttr::get(type));
2495   result.addAttribute(getLinkageAttrName(result.name),
2496                       LinkageAttr::get(builder.getContext(), linkage));
2497   result.addAttribute(getCConvAttrName(result.name),
2498                       CConvAttr::get(builder.getContext(), cconv));
2499   result.attributes.append(attrs.begin(), attrs.end());
2500   if (dsoLocal)
2501     result.addAttribute(getDsoLocalAttrName(result.name),
2502                         builder.getUnitAttr());
2503   if (comdat)
2504     result.addAttribute(getComdatAttrName(result.name), comdat);
2505   if (functionEntryCount)
2506     result.addAttribute(getFunctionEntryCountAttrName(result.name),
2507                         builder.getI64IntegerAttr(functionEntryCount.value()));
2508   if (argAttrs.empty())
2509     return;
2510 
2511   assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
2512          "expected as many argument attribute lists as arguments");
2513   function_interface_impl::addArgAndResultAttrs(
2514       builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
2515       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2516 }
2517 
2518 // Builds an LLVM function type from the given lists of input and output types.
2519 // Returns a null type if any of the types provided are non-LLVM types, or if
2520 // there is more than one output type.
2521 static Type
2522 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2523                       ArrayRef<Type> outputs,
2524                       function_interface_impl::VariadicFlag variadicFlag) {
2525   Builder &b = parser.getBuilder();
2526   if (outputs.size() > 1) {
2527     parser.emitError(loc, "failed to construct function type: expected zero or "
2528                           "one function result");
2529     return {};
2530   }
2531 
2532   // Convert inputs to LLVM types, exit early on error.
2533   SmallVector<Type, 4> llvmInputs;
2534   for (auto t : inputs) {
2535     if (!isCompatibleType(t)) {
2536       parser.emitError(loc, "failed to construct function type: expected LLVM "
2537                             "type for function arguments");
2538       return {};
2539     }
2540     llvmInputs.push_back(t);
2541   }
2542 
2543   // No output is denoted as "void" in LLVM type system.
2544   Type llvmOutput =
2545       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2546   if (!isCompatibleType(llvmOutput)) {
2547     parser.emitError(loc, "failed to construct function type: expected LLVM "
2548                           "type for function results")
2549         << llvmOutput;
2550     return {};
2551   }
2552   return LLVMFunctionType::get(llvmOutput, llvmInputs,
2553                                variadicFlag.isVariadic());
2554 }
2555 
2556 // Parses an LLVM function.
2557 //
2558 // operation ::= `llvm.func` linkage? cconv? function-signature
2559 //                (`comdat(` symbol-ref-id `)`)?
2560 //                function-attributes?
2561 //                function-body
2562 //
2563 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2564   // Default to external linkage if no keyword is provided.
2565   result.addAttribute(
2566       getLinkageAttrName(result.name),
2567       LinkageAttr::get(parser.getContext(),
2568                        parseOptionalLLVMKeyword<Linkage>(
2569                            parser, result, LLVM::Linkage::External)));
2570 
2571   // Parse optional visibility, default to Default.
2572   result.addAttribute(getVisibility_AttrName(result.name),
2573                       parser.getBuilder().getI64IntegerAttr(
2574                           parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2575                               parser, result, LLVM::Visibility::Default)));
2576 
2577   // Parse optional UnnamedAddr, default to None.
2578   result.addAttribute(getUnnamedAddrAttrName(result.name),
2579                       parser.getBuilder().getI64IntegerAttr(
2580                           parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2581                               parser, result, LLVM::UnnamedAddr::None)));
2582 
2583   // Default to C Calling Convention if no keyword is provided.
2584   result.addAttribute(
2585       getCConvAttrName(result.name),
2586       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
2587                                               parser, result, LLVM::CConv::C)));
2588 
2589   StringAttr nameAttr;
2590   SmallVector<OpAsmParser::Argument> entryArgs;
2591   SmallVector<DictionaryAttr> resultAttrs;
2592   SmallVector<Type> resultTypes;
2593   bool isVariadic;
2594 
2595   auto signatureLocation = parser.getCurrentLocation();
2596   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2597                              result.attributes) ||
2598       function_interface_impl::parseFunctionSignature(
2599           parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
2600           resultAttrs))
2601     return failure();
2602 
2603   SmallVector<Type> argTypes;
2604   for (auto &arg : entryArgs)
2605     argTypes.push_back(arg.type);
2606   auto type =
2607       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
2608                             function_interface_impl::VariadicFlag(isVariadic));
2609   if (!type)
2610     return failure();
2611   result.addAttribute(getFunctionTypeAttrName(result.name),
2612                       TypeAttr::get(type));
2613 
2614   if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
2615     int64_t minRange, maxRange;
2616     if (parser.parseLParen() || parser.parseInteger(minRange) ||
2617         parser.parseComma() || parser.parseInteger(maxRange) ||
2618         parser.parseRParen())
2619       return failure();
2620     auto intTy = IntegerType::get(parser.getContext(), 32);
2621     result.addAttribute(
2622         getVscaleRangeAttrName(result.name),
2623         LLVM::VScaleRangeAttr::get(parser.getContext(),
2624                                    IntegerAttr::get(intTy, minRange),
2625                                    IntegerAttr::get(intTy, maxRange)));
2626   }
2627   // Parse the optional comdat selector.
2628   if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2629     SymbolRefAttr comdat;
2630     if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2631         parser.parseRParen())
2632       return failure();
2633 
2634     result.addAttribute(getComdatAttrName(result.name), comdat);
2635   }
2636 
2637   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2638     return failure();
2639   function_interface_impl::addArgAndResultAttrs(
2640       parser.getBuilder(), result, entryArgs, resultAttrs,
2641       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2642 
2643   auto *body = result.addRegion();
2644   OptionalParseResult parseResult =
2645       parser.parseOptionalRegion(*body, entryArgs);
2646   return failure(parseResult.has_value() && failed(*parseResult));
2647 }
2648 
2649 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2650 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2651 // the external linkage since it is the default value.
2652 void LLVMFuncOp::print(OpAsmPrinter &p) {
2653   p << ' ';
2654   if (getLinkage() != LLVM::Linkage::External)
2655     p << stringifyLinkage(getLinkage()) << ' ';
2656   StringRef visibility = stringifyVisibility(getVisibility_());
2657   if (!visibility.empty())
2658     p << visibility << ' ';
2659   if (auto unnamedAddr = getUnnamedAddr()) {
2660     StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2661     if (!str.empty())
2662       p << str << ' ';
2663   }
2664   if (getCConv() != LLVM::CConv::C)
2665     p << stringifyCConv(getCConv()) << ' ';
2666 
2667   p.printSymbolName(getName());
2668 
2669   LLVMFunctionType fnType = getFunctionType();
2670   SmallVector<Type, 8> argTypes;
2671   SmallVector<Type, 1> resTypes;
2672   argTypes.reserve(fnType.getNumParams());
2673   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2674     argTypes.push_back(fnType.getParamType(i));
2675 
2676   Type returnType = fnType.getReturnType();
2677   if (!llvm::isa<LLVMVoidType>(returnType))
2678     resTypes.push_back(returnType);
2679 
2680   function_interface_impl::printFunctionSignature(p, *this, argTypes,
2681                                                   isVarArg(), resTypes);
2682 
2683   // Print vscale range if present
2684   if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
2685     p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
2686       << vscale->getMaxRange().getInt() << ')';
2687 
2688   // Print the optional comdat selector.
2689   if (auto comdat = getComdat())
2690     p << " comdat(" << *comdat << ')';
2691 
2692   function_interface_impl::printFunctionAttributes(
2693       p, *this,
2694       {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2695        getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
2696        getComdatAttrName(), getUnnamedAddrAttrName(),
2697        getVscaleRangeAttrName()});
2698 
2699   // Print the body if this is not an external function.
2700   Region &body = getBody();
2701   if (!body.empty()) {
2702     p << ' ';
2703     p.printRegion(body, /*printEntryBlockArgs=*/false,
2704                   /*printBlockTerminators=*/true);
2705   }
2706 }
2707 
2708 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2709 // - functions don't have 'common' linkage
2710 // - external functions have 'external' or 'extern_weak' linkage;
2711 // - vararg is (currently) only supported for external functions;
2712 LogicalResult LLVMFuncOp::verify() {
2713   if (getLinkage() == LLVM::Linkage::Common)
2714     return emitOpError() << "functions cannot have '"
2715                          << stringifyLinkage(LLVM::Linkage::Common)
2716                          << "' linkage";
2717 
2718   if (failed(verifyComdat(*this, getComdat())))
2719     return failure();
2720 
2721   if (isExternal()) {
2722     if (getLinkage() != LLVM::Linkage::External &&
2723         getLinkage() != LLVM::Linkage::ExternWeak)
2724       return emitOpError() << "external functions must have '"
2725                            << stringifyLinkage(LLVM::Linkage::External)
2726                            << "' or '"
2727                            << stringifyLinkage(LLVM::Linkage::ExternWeak)
2728                            << "' linkage";
2729     return success();
2730   }
2731 
2732   // In LLVM IR, these attributes are composed by convention, not by design.
2733   if (isNoInline() && isAlwaysInline())
2734     return emitError("no_inline and always_inline attributes are incompatible");
2735 
2736   if (isOptimizeNone() && !isNoInline())
2737     return emitOpError("with optimize_none must also be no_inline");
2738 
2739   Type landingpadResultTy;
2740   StringRef diagnosticMessage;
2741   bool isLandingpadTypeConsistent =
2742       !walk([&](Operation *op) {
2743          const auto checkType = [&](Type type, StringRef errorMessage) {
2744            if (!landingpadResultTy) {
2745              landingpadResultTy = type;
2746              return WalkResult::advance();
2747            }
2748            if (landingpadResultTy != type) {
2749              diagnosticMessage = errorMessage;
2750              return WalkResult::interrupt();
2751            }
2752            return WalkResult::advance();
2753          };
2754          return TypeSwitch<Operation *, WalkResult>(op)
2755              .Case<LandingpadOp>([&](auto landingpad) {
2756                constexpr StringLiteral errorMessage =
2757                    "'llvm.landingpad' should have a consistent result type "
2758                    "inside a function";
2759                return checkType(landingpad.getType(), errorMessage);
2760              })
2761              .Case<ResumeOp>([&](auto resume) {
2762                constexpr StringLiteral errorMessage =
2763                    "'llvm.resume' should have a consistent input type inside a "
2764                    "function";
2765                return checkType(resume.getValue().getType(), errorMessage);
2766              })
2767              .Default([](auto) { return WalkResult::skip(); });
2768        }).wasInterrupted();
2769   if (!isLandingpadTypeConsistent) {
2770     assert(!diagnosticMessage.empty() &&
2771            "Expecting a non-empty diagnostic message");
2772     return emitError(diagnosticMessage);
2773   }
2774 
2775   return success();
2776 }
2777 
2778 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2779 /// - entry block arguments are of LLVM types.
2780 LogicalResult LLVMFuncOp::verifyRegions() {
2781   if (isExternal())
2782     return success();
2783 
2784   unsigned numArguments = getFunctionType().getNumParams();
2785   Block &entryBlock = front();
2786   for (unsigned i = 0; i < numArguments; ++i) {
2787     Type argType = entryBlock.getArgument(i).getType();
2788     if (!isCompatibleType(argType))
2789       return emitOpError("entry block argument #")
2790              << i << " is not of LLVM type";
2791   }
2792 
2793   return success();
2794 }
2795 
2796 Region *LLVMFuncOp::getCallableRegion() {
2797   if (isExternal())
2798     return nullptr;
2799   return &getBody();
2800 }
2801 
2802 //===----------------------------------------------------------------------===//
2803 // UndefOp.
2804 //===----------------------------------------------------------------------===//
2805 
2806 /// Fold an undef operation to a dedicated undef attribute.
2807 OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
2808   return LLVM::UndefAttr::get(getContext());
2809 }
2810 
2811 //===----------------------------------------------------------------------===//
2812 // PoisonOp.
2813 //===----------------------------------------------------------------------===//
2814 
2815 /// Fold a poison operation to a dedicated poison attribute.
2816 OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
2817   return LLVM::PoisonAttr::get(getContext());
2818 }
2819 
2820 //===----------------------------------------------------------------------===//
2821 // ZeroOp.
2822 //===----------------------------------------------------------------------===//
2823 
2824 LogicalResult LLVM::ZeroOp::verify() {
2825   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
2826     if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
2827       return emitOpError()
2828              << "target extension type does not support zero-initializer";
2829 
2830   return success();
2831 }
2832 
2833 /// Fold a zero operation to a builtin zero attribute when possible and fall
2834 /// back to a dedicated zero attribute.
2835 OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
2836   OpFoldResult result = Builder(getContext()).getZeroAttr(getType());
2837   if (result)
2838     return result;
2839   return LLVM::ZeroAttr::get(getContext());
2840 }
2841 
2842 //===----------------------------------------------------------------------===//
2843 // ConstantOp.
2844 //===----------------------------------------------------------------------===//
2845 
2846 /// Compute the total number of elements in the given type, also taking into
2847 /// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
2848 /// `LLVMFixedVectorType`. Everything else is treated as a scalar.
2849 static int64_t getNumElements(Type t) {
2850   if (auto vecType = dyn_cast<VectorType>(t))
2851     return vecType.getNumElements() * getNumElements(vecType.getElementType());
2852   if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2853     return arrayType.getNumElements() *
2854            getNumElements(arrayType.getElementType());
2855   if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2856     return vecType.getNumElements() * getNumElements(vecType.getElementType());
2857   assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
2858          "number of elements of a scalable vector type is unknown");
2859   return 1;
2860 }
2861 
2862 /// Check if the given type is a scalable vector type or a vector/array type
2863 /// that contains a nested scalable vector type.
2864 static bool hasScalableVectorType(Type t) {
2865   if (isa<LLVM::LLVMScalableVectorType>(t))
2866     return true;
2867   if (auto vecType = dyn_cast<VectorType>(t)) {
2868     if (vecType.isScalable())
2869       return true;
2870     return hasScalableVectorType(vecType.getElementType());
2871   }
2872   if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2873     return hasScalableVectorType(arrayType.getElementType());
2874   if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2875     return hasScalableVectorType(vecType.getElementType());
2876   return false;
2877 }
2878 
/// Verifies that the value attribute of the constant is consistent with the
/// result type. The checks are dispatched, in this order, on: a string
/// attribute, a struct result type, a target extension result type, and
/// finally integer, float, elements or array attributes. Any other attribute
/// kind is rejected.
LogicalResult LLVM::ConstantOp::verify() {
  // String constants: the result must be an LLVM array of i8 whose number of
  // elements equals the size of the string.
  if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
    auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
    if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
        !arrayType.getElementType().isInteger(8)) {
      return emitOpError() << "expected array type of "
                           << sAttr.getValue().size()
                           << " i8 elements for the string constant";
    }
    return success();
  }
  // Struct constants: the value must be an array attribute with one integer
  // or float attribute per struct element, each of the matching type.
  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
    auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
    if (!arrayAttr) {
      return emitOpError() << "expected array attribute for a struct constant";
    }

    ArrayRef<Type> elementTypes = structType.getBody();
    if (arrayAttr.size() != elementTypes.size()) {
      return emitOpError() << "expected array attribute of size "
                           << elementTypes.size();
    }
    // Only integer, float and ppc_fp128 struct element types are accepted.
    for (auto elementTy : elementTypes) {
      if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
        return emitOpError() << "expected struct element types to be floating "
                                "point type or integer type";
      }
    }

    // Each element attribute must be an integer or float attribute whose type
    // is exactly the type of the corresponding struct element.
    for (size_t i = 0; i < elementTypes.size(); ++i) {
      Attribute element = arrayAttr[i];
      if (!isa<IntegerAttr, FloatAttr>(element)) {
        return emitOpError()
               << "expected struct element attribute types to be floating "
                  "point type or integer type";
      }
      auto elementType = cast<TypedAttr>(element).getType();
      if (elementType != elementTypes[i]) {
        return emitOpError()
               << "struct element at index " << i << " is of wrong type";
      }
    }

    return success();
  }
  // Target extension types are rejected regardless of the attribute.
  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
    return emitOpError() << "does not support target extension type.";
  }

  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
  if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
    // Integer attributes require an integer result type.
    if (!llvm::isa<IntegerType>(getType()))
      return emitOpError() << "expected integer type";
  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
    // Float attributes: the width implied by the attribute's semantics must
    // match the width of a float result type, or of an integer result type.
    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
    unsigned floatWidth = APFloat::getSizeInBits(sem);
    if (auto floatTy = dyn_cast<FloatType>(getType())) {
      if (floatTy.getWidth() != floatWidth) {
        return emitOpError() << "expected float type of width " << floatWidth;
      }
    }
    // See the comment for getLLVMConstant for more details about why 8-bit
    // floats can be represented by integers.
    if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
      return emitOpError() << "expected integer type of width " << floatWidth;
    }
  } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
    if (hasScalableVectorType(getType())) {
      // The exact number of elements of a scalable vector is unknown, so we
      // allow only splat attributes.
      auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
      if (!splatElementsAttr)
        return emitOpError()
               << "scalable vector type requires a splat attribute";
      return success();
    }
    if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
            getType()))
      return emitOpError() << "expected vector or array type";
    // The number of elements of the attribute and the type must match.
    int64_t attrNumElements;
    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
      attrNumElements = elementsAttr.getNumElements();
    else
      attrNumElements = cast<ArrayAttr>(getValue()).size();
    if (getNumElements(getType()) != attrNumElements)
      return emitOpError()
             << "type and attribute have a different number of elements: "
             << getNumElements(getType()) << " vs. " << attrNumElements;
  } else {
    return emitOpError()
           << "only supports integer, float, string or elements attributes";
  }

  return success();
}
2975 
2976 bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
2977   // The value's type must be the same as the provided type.
2978   auto typedAttr = dyn_cast<TypedAttr>(value);
2979   if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
2980     return false;
2981   // The value's type must be an LLVM compatible type.
2982   if (!isCompatibleType(type))
2983     return false;
2984   // TODO: Add support for additional attributes kinds once needed.
2985   return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
2986 }
2987 
2988 ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
2989                                          Type type, Location loc) {
2990   if (isBuildableWith(value, type))
2991     return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
2992   return nullptr;
2993 }
2994 
2995 // Constant op constant-folds to its value.
2996 OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
2997 
2998 //===----------------------------------------------------------------------===//
2999 // AtomicRMWOp
3000 //===----------------------------------------------------------------------===//
3001 
3002 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3003                         AtomicBinOp binOp, Value ptr, Value val,
3004                         AtomicOrdering ordering, StringRef syncscope,
3005                         unsigned alignment, bool isVolatile) {
3006   build(builder, state, val.getType(), binOp, ptr, val, ordering,
3007         !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3008         alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3009         /*access_groups=*/nullptr,
3010         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3011 }
3012 
3013 LogicalResult AtomicRMWOp::verify() {
3014   auto valType = getVal().getType();
3015   if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3016       getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
3017     if (isCompatibleVectorType(valType)) {
3018       if (isScalableVectorType(valType))
3019         return emitOpError("expected LLVM IR fixed vector type");
3020       Type elemType = getVectorElementType(valType);
3021       if (!isCompatibleFloatingPointType(elemType))
3022         return emitOpError(
3023             "expected LLVM IR floating point type for vector element");
3024     } else if (!isCompatibleFloatingPointType(valType)) {
3025       return emitOpError("expected LLVM IR floating point type");
3026     }
3027   } else if (getBinOp() == AtomicBinOp::xchg) {
3028     DataLayout dataLayout = DataLayout::closest(*this);
3029     if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3030       return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3031   } else {
3032     auto intType = llvm::dyn_cast<IntegerType>(valType);
3033     unsigned intBitWidth = intType ? intType.getWidth() : 0;
3034     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3035         intBitWidth != 64)
3036       return emitOpError("expected LLVM IR integer type");
3037   }
3038 
3039   if (static_cast<unsigned>(getOrdering()) <
3040       static_cast<unsigned>(AtomicOrdering::monotonic))
3041     return emitOpError() << "expected at least '"
3042                          << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3043                          << "' ordering";
3044 
3045   return success();
3046 }
3047 
3048 //===----------------------------------------------------------------------===//
3049 // AtomicCmpXchgOp
3050 //===----------------------------------------------------------------------===//
3051 
3052 /// Returns an LLVM struct type that contains a value type and a boolean type.
3053 static LLVMStructType getValAndBoolStructType(Type valType) {
3054   auto boolType = IntegerType::get(valType.getContext(), 1);
3055   return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3056 }
3057 
3058 void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3059                             Value ptr, Value cmp, Value val,
3060                             AtomicOrdering successOrdering,
3061                             AtomicOrdering failureOrdering, StringRef syncscope,
3062                             unsigned alignment, bool isWeak, bool isVolatile) {
3063   build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3064         successOrdering, failureOrdering,
3065         !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3066         alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3067         isVolatile, /*access_groups=*/nullptr,
3068         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3069 }
3070 
3071 LogicalResult AtomicCmpXchgOp::verify() {
3072   auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3073   if (!ptrType)
3074     return emitOpError("expected LLVM IR pointer type for operand #0");
3075   auto valType = getVal().getType();
3076   DataLayout dataLayout = DataLayout::closest(*this);
3077   if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3078     return emitOpError("unexpected LLVM IR type");
3079   if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3080       getFailureOrdering() < AtomicOrdering::monotonic)
3081     return emitOpError("ordering must be at least 'monotonic'");
3082   if (getFailureOrdering() == AtomicOrdering::release ||
3083       getFailureOrdering() == AtomicOrdering::acq_rel)
3084     return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3085   return success();
3086 }
3087 
3088 //===----------------------------------------------------------------------===//
3089 // FenceOp
3090 //===----------------------------------------------------------------------===//
3091 
3092 void FenceOp::build(OpBuilder &builder, OperationState &state,
3093                     AtomicOrdering ordering, StringRef syncscope) {
3094   build(builder, state, ordering,
3095         syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3096 }
3097 
3098 LogicalResult FenceOp::verify() {
3099   if (getOrdering() == AtomicOrdering::not_atomic ||
3100       getOrdering() == AtomicOrdering::unordered ||
3101       getOrdering() == AtomicOrdering::monotonic)
3102     return emitOpError("can be given only acquire, release, acq_rel, "
3103                        "and seq_cst orderings");
3104   return success();
3105 }
3106 
3107 //===----------------------------------------------------------------------===//
3108 // Verifier for extension ops
3109 //===----------------------------------------------------------------------===//
3110 
3111 /// Verifies that the given extension operation operates on consistent scalars
3112 /// or vectors, and that the target width is larger than the input width.
3113 template <class ExtOp>
3114 static LogicalResult verifyExtOp(ExtOp op) {
3115   IntegerType inputType, outputType;
3116   if (isCompatibleVectorType(op.getArg().getType())) {
3117     if (!isCompatibleVectorType(op.getResult().getType()))
3118       return op.emitError(
3119           "input type is a vector but output type is an integer");
3120     if (getVectorNumElements(op.getArg().getType()) !=
3121         getVectorNumElements(op.getResult().getType()))
3122       return op.emitError("input and output vectors are of incompatible shape");
3123     // Because this is a CastOp, the element of vectors is guaranteed to be an
3124     // integer.
3125     inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
3126     outputType =
3127         cast<IntegerType>(getVectorElementType(op.getResult().getType()));
3128   } else {
3129     // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3130     // an integer.
3131     inputType = cast<IntegerType>(op.getArg().getType());
3132     outputType = dyn_cast<IntegerType>(op.getResult().getType());
3133     if (!outputType)
3134       return op.emitError(
3135           "input type is an integer but output type is a vector");
3136   }
3137 
3138   if (outputType.getWidth() <= inputType.getWidth())
3139     return op.emitError("integer width of the output type is smaller or "
3140                         "equal to the integer width of the input type");
3141   return success();
3142 }
3143 
3144 //===----------------------------------------------------------------------===//
3145 // ZExtOp
3146 //===----------------------------------------------------------------------===//
3147 
3148 LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3149 
3150 OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3151   auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3152   if (!arg)
3153     return {};
3154 
3155   size_t targetSize = cast<IntegerType>(getType()).getWidth();
3156   return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3157 }
3158 
3159 //===----------------------------------------------------------------------===//
3160 // SExtOp
3161 //===----------------------------------------------------------------------===//
3162 
3163 LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3164 
3165 //===----------------------------------------------------------------------===//
3166 // Folder and verifier for LLVM::BitcastOp
3167 //===----------------------------------------------------------------------===//
3168 
3169 /// Folds a cast op that can be chained.
3170 template <typename T>
3171 static OpFoldResult foldChainableCast(T castOp,
3172                                       typename T::FoldAdaptor adaptor) {
3173   // cast(x : T0, T0) -> x
3174   if (castOp.getArg().getType() == castOp.getType())
3175     return castOp.getArg();
3176   if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3177     // cast(cast(x : T0, T1), T0) -> x
3178     if (prev.getArg().getType() == castOp.getType())
3179       return prev.getArg();
3180     // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3181     castOp.getArgMutable().set(prev.getArg());
3182     return Value{castOp};
3183   }
3184   return {};
3185 }
3186 
3187 OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3188   return foldChainableCast(*this, adaptor);
3189 }
3190 
3191 LogicalResult LLVM::BitcastOp::verify() {
3192   auto resultType = llvm::dyn_cast<LLVMPointerType>(
3193       extractVectorElementType(getResult().getType()));
3194   auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3195       extractVectorElementType(getArg().getType()));
3196 
3197   // If one of the types is a pointer (or vector of pointers), then
3198   // both source and result type have to be pointers.
3199   if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3200     return emitOpError("can only cast pointers from and to pointers");
3201 
3202   if (!resultType)
3203     return success();
3204 
3205   auto isVector =
3206       llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
3207 
3208   // Due to bitcast requiring both operands to be of the same size, it is not
3209   // possible for only one of the two to be a pointer of vectors.
3210   if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3211     return emitOpError("cannot cast pointer to vector of pointers");
3212 
3213   if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3214     return emitOpError("cannot cast vector of pointers to pointer");
3215 
3216   // Bitcast cannot cast between pointers of different address spaces.
3217   // 'llvm.addrspacecast' must be used for this purpose instead.
3218   if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3219     return emitOpError("cannot cast pointers of different address spaces, "
3220                        "use 'llvm.addrspacecast' instead");
3221 
3222   return success();
3223 }
3224 
3225 //===----------------------------------------------------------------------===//
3226 // Folder for LLVM::AddrSpaceCastOp
3227 //===----------------------------------------------------------------------===//
3228 
3229 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3230   return foldChainableCast(*this, adaptor);
3231 }
3232 
3233 Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3234 
3235 //===----------------------------------------------------------------------===//
3236 // Folder for LLVM::GEPOp
3237 //===----------------------------------------------------------------------===//
3238 
/// Folds a GEP in two ways: a GEP with a single zero index whose result type
/// equals the base type folds to the base, and dynamic indices that have a
/// constant value fitting in `kGEPConstantBitWidth` signed bits are turned
/// into constant indices by updating the op in place.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep the index as it currently is on the op: either the dynamic value
      // or the already-constant integer.
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());

      continue;
    }

    // A dynamic index with a small enough constant value becomes constant.
    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Rebuild the constant/dynamic index split from the new argument list and
    // write it back to the op; the op itself is returned as the fold result.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
3284 
3285 Value LLVM::GEPOp::getViewSource() { return getBase(); }
3286 
3287 //===----------------------------------------------------------------------===//
3288 // ShlOp
3289 //===----------------------------------------------------------------------===//
3290 
3291 OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3292   auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3293   if (!rhs)
3294     return {};
3295 
3296   if (rhs.getValue().getZExtValue() >=
3297       getLhs().getType().getIntOrFloatBitWidth())
3298     return {}; // TODO: Fold into poison.
3299 
3300   auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3301   if (!lhs)
3302     return {};
3303 
3304   return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
3305 }
3306 
3307 //===----------------------------------------------------------------------===//
3308 // OrOp
3309 //===----------------------------------------------------------------------===//
3310 
3311 OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
3312   auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3313   if (!lhs)
3314     return {};
3315 
3316   auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3317   if (!rhs)
3318     return {};
3319 
3320   return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
3321 }
3322 
3323 //===----------------------------------------------------------------------===//
3324 // CallIntrinsicOp
3325 //===----------------------------------------------------------------------===//
3326 
3327 LogicalResult CallIntrinsicOp::verify() {
3328   if (!getIntrin().starts_with("llvm."))
3329     return emitOpError() << "intrinsic name must start with 'llvm.'";
3330   if (failed(verifyOperandBundles(*this)))
3331     return failure();
3332   return success();
3333 }
3334 
/// Builds an intrinsic call with no result types, a default-constructed
/// fastmath flags attribute, and no operand bundles.
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                            mlir::StringAttr intrin, mlir::ValueRange args) {
  build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
        FastmathFlagsAttr{},
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
3341 
/// Builds an intrinsic call with no result types, the given fastmath flags,
/// and no operand bundles.
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                            mlir::StringAttr intrin, mlir::ValueRange args,
                            mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
  build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
        fastMathFlags,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
3349 
/// Builds an intrinsic call with a single result type, a default-constructed
/// fastmath flags attribute, and no operand bundles.
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                            mlir::Type resultType, mlir::StringAttr intrin,
                            mlir::ValueRange args) {
  build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
3356 
/// Builds an intrinsic call with the given result types and fastmath flags,
/// and no operand bundles.
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                            mlir::TypeRange resultTypes,
                            mlir::StringAttr intrin, mlir::ValueRange args,
                            mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
  build(builder, state, resultTypes, intrin, args, fastMathFlags,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
3364 
3365 //===----------------------------------------------------------------------===//
3366 // OpAsmDialectInterface
3367 //===----------------------------------------------------------------------===//
3368 
namespace {
/// Assembly interface of the LLVM dialect; it only customizes attribute
/// aliases.
struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  /// For the attribute kinds listed below, writes the attribute's mnemonic to
  /// `os` as the alias name and returns `OverridableAlias`. All other
  /// attributes get `NoAlias`.
  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    return TypeSwitch<Attribute, AliasResult>(attr)
        .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr,
              DIBasicTypeAttr, DICommonBlockAttr, DICompileUnitAttr,
              DICompositeTypeAttr, DIDerivedTypeAttr, DIFileAttr,
              DIGlobalVariableAttr, DIGlobalVariableExpressionAttr,
              DIImportedEntityAttr, DILabelAttr, DILexicalBlockAttr,
              DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr,
              DINamespaceAttr, DINullTypeAttr, DIStringTypeAttr,
              DISubprogramAttr, DISubroutineTypeAttr, LoopAnnotationAttr,
              LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
              LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
              LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
              TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
          // The generic lambda sees the concrete attribute class, so the
          // mnemonic is looked up statically on that class.
          os << decltype(attr)::getMnemonic();
          return AliasResult::OverridableAlias;
        })
        .Default([](Attribute) { return AliasResult::NoAlias; });
  }
};
} // namespace
3394 
3395 //===----------------------------------------------------------------------===//
3396 // LinkerOptionsOp
3397 //===----------------------------------------------------------------------===//
3398 
3399 LogicalResult LinkerOptionsOp::verify() {
3400   if (mlir::Operation *parentOp = (*this)->getParentOp();
3401       parentOp && !satisfiesLLVMModule(parentOp))
3402     return emitOpError("must appear at the module level");
3403   return success();
3404 }
3405 
3406 //===----------------------------------------------------------------------===//
3407 // InlineAsmOp
3408 //===----------------------------------------------------------------------===//
3409 
3410 void InlineAsmOp::getEffects(
3411     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3412         &effects) {
3413   if (getHasSideEffects()) {
3414     effects.emplace_back(MemoryEffects::Write::get());
3415     effects.emplace_back(MemoryEffects::Read::get());
3416   }
3417 }
3418 
3419 //===----------------------------------------------------------------------===//
3420 // AssumeOp (intrinsic)
3421 //===----------------------------------------------------------------------===//
3422 
/// Builds an assume op on `cond` with no operand bundles.
void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value cond) {
  return build(builder, state, cond, /*op_bundle_operands=*/{},
               /*op_bundle_tags=*/ArrayAttr{});
}
3428 
3429 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3430                            Value cond,
3431                            ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) {
3432   SmallVector<ValueRange> opBundleOperands;
3433   SmallVector<Attribute> opBundleTags;
3434   opBundleOperands.reserve(opBundles.size());
3435   opBundleTags.reserve(opBundles.size());
3436 
3437   for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) {
3438     opBundleOperands.emplace_back(bundle.inputs());
3439     opBundleTags.push_back(
3440         StringAttr::get(builder.getContext(), bundle.getTag()));
3441   }
3442 
3443   auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
3444   return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
3445 }
3446 
3447 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3448                            Value cond, llvm::StringRef tag, ValueRange args) {
3449   llvm::OperandBundleDefT<Value> opBundle(
3450       tag.str(), SmallVector<Value>(args.begin(), args.end()));
3451   return build(builder, state, cond, opBundle);
3452 }
3453 
/// Builds an assume op on `cond` with a single "align" operand bundle whose
/// operands are `ptr` and `align`.
void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                           Value cond, AssumeAlignTag, Value ptr, Value align) {
  return build(builder, state, cond, "align", ValueRange{ptr, align});
}
3458 
/// Builds an assume op on `cond` with a single "separate_storage" operand
/// bundle whose operands are `ptr1` and `ptr2`.
void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                           Value cond, AssumeSeparateStorageTag, Value ptr1,
                           Value ptr2) {
  return build(builder, state, cond, "separate_storage",
               ValueRange{ptr1, ptr2});
}
3465 
3466 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
3467 
3468 //===----------------------------------------------------------------------===//
3469 // masked_gather (intrinsic)
3470 //===----------------------------------------------------------------------===//
3471 
3472 LogicalResult LLVM::masked_gather::verify() {
3473   auto ptrsVectorType = getPtrs().getType();
3474   Type expectedPtrsVectorType =
3475       LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
3476                           LLVM::getVectorNumElements(getRes().getType()));
3477   // Vector of pointers type should match result vector type, other than the
3478   // element type.
3479   if (ptrsVectorType != expectedPtrsVectorType)
3480     return emitOpError("expected operand #1 type to be ")
3481            << expectedPtrsVectorType;
3482   return success();
3483 }
3484 
3485 //===----------------------------------------------------------------------===//
3486 // masked_scatter (intrinsic)
3487 //===----------------------------------------------------------------------===//
3488 
3489 LogicalResult LLVM::masked_scatter::verify() {
3490   auto ptrsVectorType = getPtrs().getType();
3491   Type expectedPtrsVectorType =
3492       LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
3493                           LLVM::getVectorNumElements(getValue().getType()));
3494   // Vector of pointers type should match value vector type, other than the
3495   // element type.
3496   if (ptrsVectorType != expectedPtrsVectorType)
3497     return emitOpError("expected operand #2 type to be ")
3498            << expectedPtrsVectorType;
3499   return success();
3500 }
3501 
3502 //===----------------------------------------------------------------------===//
3503 // LLVMDialect initialization, type parsing, and registration.
3504 //===----------------------------------------------------------------------===//
3505 
/// Registers the dialect's attributes, types, operations and interfaces.
void LLVMDialect::initialize() {
  registerAttributes();

  // These types are added explicitly here, before the registerTypes() call
  // below.
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType>();
  // clang-format on
  registerTypes();

  // The operation list is the concatenation of the op lists of the two
  // included .inc files.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
  // clang-format off
  addInterfaces<LLVMOpAsmDialectInterface>();
  // clang-format on
  // The inliner interface is only declared as promised here; it is not
  // attached in this function.
  declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
}
3533 
3534 #define GET_OP_CLASSES
3535 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
3536 
3537 #define GET_OP_CLASSES
3538 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
3539 
3540 LogicalResult LLVMDialect::verifyDataLayoutString(
3541     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
3542   llvm::Expected<llvm::DataLayout> maybeDataLayout =
3543       llvm::DataLayout::parse(descr);
3544   if (maybeDataLayout)
3545     return success();
3546 
3547   std::string message;
3548   llvm::raw_string_ostream messageStream(message);
3549   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
3550   reportError("invalid data layout descriptor: " + message);
3551   return failure();
3552 }
3553 
3554 /// Verify LLVM dialect attributes.
3555 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
3556                                                     NamedAttribute attr) {
3557   // If the data layout attribute is present, it must use the LLVM data layout
3558   // syntax. Try parsing it and report errors in case of failure. Users of this
3559   // attribute may assume it is well-formed and can pass it to the (asserting)
3560   // llvm::DataLayout constructor.
3561   if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
3562     return success();
3563   if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
3564     return verifyDataLayoutString(
3565         stringAttr.getValue(),
3566         [op](const Twine &message) { op->emitOpError() << message.str(); });
3567 
3568   return op->emitOpError() << "expected '"
3569                            << LLVM::LLVMDialect::getDataLayoutAttrName()
3570                            << "' to be a string attributes";
3571 }
3572 
3573 LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
3574                                                     Type paramType,
3575                                                     NamedAttribute paramAttr) {
  // An LLVM attribute may be attached to a result of an operation that has not
  // been converted to the LLVM dialect yet, so the result may have a type with
  // no known representation in the LLVM dialect type space. In that case we
  // cannot verify whether the attribute may be attached to a value of that
  // type, so the type-dependent checks below are skipped.
3580   bool verifyValueType = isCompatibleType(paramType);
3581   StringAttr name = paramAttr.getName();
3582 
3583   auto checkUnitAttrType = [&]() -> LogicalResult {
3584     if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
3585       return op->emitError() << name << " should be a unit attribute";
3586     return success();
3587   };
3588   auto checkTypeAttrType = [&]() -> LogicalResult {
3589     if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
3590       return op->emitError() << name << " should be a type attribute";
3591     return success();
3592   };
3593   auto checkIntegerAttrType = [&]() -> LogicalResult {
3594     if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
3595       return op->emitError() << name << " should be an integer attribute";
3596     return success();
3597   };
3598   auto checkPointerType = [&]() -> LogicalResult {
3599     if (!llvm::isa<LLVMPointerType>(paramType))
3600       return op->emitError()
3601              << name << " attribute attached to non-pointer LLVM type";
3602     return success();
3603   };
3604   auto checkIntegerType = [&]() -> LogicalResult {
3605     if (!llvm::isa<IntegerType>(paramType))
3606       return op->emitError()
3607              << name << " attribute attached to non-integer LLVM type";
3608     return success();
3609   };
3610   auto checkPointerTypeMatches = [&]() -> LogicalResult {
3611     if (failed(checkPointerType()))
3612       return failure();
3613 
3614     return success();
3615   };
3616 
3617   // Check a unit attribute that is attached to a pointer value.
3618   if (name == LLVMDialect::getNoAliasAttrName() ||
3619       name == LLVMDialect::getReadonlyAttrName() ||
3620       name == LLVMDialect::getReadnoneAttrName() ||
3621       name == LLVMDialect::getWriteOnlyAttrName() ||
3622       name == LLVMDialect::getNestAttrName() ||
3623       name == LLVMDialect::getNoCaptureAttrName() ||
3624       name == LLVMDialect::getNoFreeAttrName() ||
3625       name == LLVMDialect::getNonNullAttrName()) {
3626     if (failed(checkUnitAttrType()))
3627       return failure();
3628     if (verifyValueType && failed(checkPointerType()))
3629       return failure();
3630     return success();
3631   }
3632 
3633   // Check a type attribute that is attached to a pointer value.
3634   if (name == LLVMDialect::getStructRetAttrName() ||
3635       name == LLVMDialect::getByValAttrName() ||
3636       name == LLVMDialect::getByRefAttrName() ||
3637       name == LLVMDialect::getInAllocaAttrName() ||
3638       name == LLVMDialect::getPreallocatedAttrName()) {
3639     if (failed(checkTypeAttrType()))
3640       return failure();
3641     if (verifyValueType && failed(checkPointerTypeMatches()))
3642       return failure();
3643     return success();
3644   }
3645 
3646   // Check a unit attribute that is attached to an integer value.
3647   if (name == LLVMDialect::getSExtAttrName() ||
3648       name == LLVMDialect::getZExtAttrName()) {
3649     if (failed(checkUnitAttrType()))
3650       return failure();
3651     if (verifyValueType && failed(checkIntegerType()))
3652       return failure();
3653     return success();
3654   }
3655 
3656   // Check an integer attribute that is attached to a pointer value.
3657   if (name == LLVMDialect::getAlignAttrName() ||
3658       name == LLVMDialect::getDereferenceableAttrName() ||
3659       name == LLVMDialect::getDereferenceableOrNullAttrName() ||
3660       name == LLVMDialect::getStackAlignmentAttrName()) {
3661     if (failed(checkIntegerAttrType()))
3662       return failure();
3663     if (verifyValueType && failed(checkPointerType()))
3664       return failure();
3665     return success();
3666   }
3667 
3668   // Check a unit attribute that can be attached to arbitrary types.
3669   if (name == LLVMDialect::getNoUndefAttrName() ||
3670       name == LLVMDialect::getInRegAttrName() ||
3671       name == LLVMDialect::getReturnedAttrName())
3672     return checkUnitAttrType();
3673 
3674   return success();
3675 }
3676 
3677 /// Verify LLVMIR function argument attributes.
3678 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
3679                                                     unsigned regionIdx,
3680                                                     unsigned argIdx,
3681                                                     NamedAttribute argAttr) {
3682   auto funcOp = dyn_cast<FunctionOpInterface>(op);
3683   if (!funcOp)
3684     return success();
3685   Type argType = funcOp.getArgumentTypes()[argIdx];
3686 
3687   return verifyParameterAttribute(op, argType, argAttr);
3688 }
3689 
3690 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
3691                                                        unsigned regionIdx,
3692                                                        unsigned resIdx,
3693                                                        NamedAttribute resAttr) {
3694   auto funcOp = dyn_cast<FunctionOpInterface>(op);
3695   if (!funcOp)
3696     return success();
3697   Type resType = funcOp.getResultTypes()[resIdx];
3698 
3699   // Check to see if this function has a void return with a result attribute
3700   // to it. It isn't clear what semantics we would assign to that.
3701   if (llvm::isa<LLVMVoidType>(resType))
3702     return op->emitError() << "cannot attach result attributes to functions "
3703                               "with a void return";
3704 
3705   // Check to see if this attribute is allowed as a result attribute. Only
3706   // explicitly forbidden LLVM attributes will cause an error.
3707   auto name = resAttr.getName();
3708   if (name == LLVMDialect::getAllocAlignAttrName() ||
3709       name == LLVMDialect::getAllocatedPointerAttrName() ||
3710       name == LLVMDialect::getByValAttrName() ||
3711       name == LLVMDialect::getByRefAttrName() ||
3712       name == LLVMDialect::getInAllocaAttrName() ||
3713       name == LLVMDialect::getNestAttrName() ||
3714       name == LLVMDialect::getNoCaptureAttrName() ||
3715       name == LLVMDialect::getNoFreeAttrName() ||
3716       name == LLVMDialect::getPreallocatedAttrName() ||
3717       name == LLVMDialect::getReadnoneAttrName() ||
3718       name == LLVMDialect::getReadonlyAttrName() ||
3719       name == LLVMDialect::getReturnedAttrName() ||
3720       name == LLVMDialect::getStackAlignmentAttrName() ||
3721       name == LLVMDialect::getStructRetAttrName() ||
3722       name == LLVMDialect::getWriteOnlyAttrName())
3723     return op->emitError() << name << " is not a valid result attribute";
3724   return verifyParameterAttribute(op, resType, resAttr);
3725 }
3726 
3727 Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
3728                                             Type type, Location loc) {
3729   // If this was folded from an operation other than llvm.mlir.constant, it
3730   // should be materialized as such. Note that an llvm.mlir.zero may fold into
3731   // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
3732   if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
3733     if (isa<LLVM::LLVMPointerType>(type))
3734       return builder.create<LLVM::AddressOfOp>(loc, type, symbol);
3735   if (isa<LLVM::UndefAttr>(value))
3736     return builder.create<LLVM::UndefOp>(loc, type);
3737   if (isa<LLVM::PoisonAttr>(value))
3738     return builder.create<LLVM::PoisonOp>(loc, type);
3739   if (isa<LLVM::ZeroAttr>(value))
3740     return builder.create<LLVM::ZeroOp>(loc, type);
3741   // Otherwise try materializing it as a regular llvm.mlir.constant op.
3742   return LLVM::ConstantOp::materialize(builder, value, type, loc);
3743 }
3744 
3745 //===----------------------------------------------------------------------===//
3746 // Utility functions.
3747 //===----------------------------------------------------------------------===//
3748 
3749 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
3750                                      StringRef name, StringRef value,
3751                                      LLVM::Linkage linkage) {
3752   assert(builder.getInsertionBlock() &&
3753          builder.getInsertionBlock()->getParentOp() &&
3754          "expected builder to point to a block constrained in an op");
3755   auto module =
3756       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
3757   assert(module && "builder points to an op outside of a module");
3758 
3759   // Create the global at the entry of the module.
3760   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
3761   MLIRContext *ctx = builder.getContext();
3762   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
3763   auto global = moduleBuilder.create<LLVM::GlobalOp>(
3764       loc, type, /*isConstant=*/true, linkage, name,
3765       builder.getStringAttr(value), /*alignment=*/0);
3766 
3767   LLVMPointerType ptrType = LLVMPointerType::get(ctx);
3768   // Get the pointer to the first character in the global string.
3769   Value globalPtr =
3770       builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
3771   return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr,
3772                                      ArrayRef<GEPArg>{0, 0});
3773 }
3774 
3775 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
3776   return op->hasTrait<OpTrait::SymbolTable>() &&
3777          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
3778 }
3779