1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "TypeDetail.h" 16 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" 17 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" 18 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/Interfaces/FunctionImplementation.h" 27 #include "mlir/Transforms/InliningUtils.h" 28 29 #include "llvm/ADT/SCCIterator.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/AsmParser/Parser.h" 32 #include "llvm/Bitcode/BitcodeReader.h" 33 #include "llvm/Bitcode/BitcodeWriter.h" 34 #include "llvm/IR/Attributes.h" 35 #include "llvm/IR/Function.h" 36 #include "llvm/IR/Type.h" 37 #include "llvm/Support/Error.h" 38 #include "llvm/Support/Mutex.h" 39 #include "llvm/Support/SourceMgr.h" 40 41 #include <numeric> 42 #include <optional> 43 44 using namespace mlir; 45 using namespace mlir::LLVM; 46 using mlir::LLVM::cconv::getMaxEnumValForCConv; 47 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 48 using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind; 49 50 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 51 52 
//===----------------------------------------------------------------------===// 53 // Property Helpers 54 //===----------------------------------------------------------------------===// 55 56 //===----------------------------------------------------------------------===// 57 // IntegerOverflowFlags 58 59 namespace mlir { 60 static Attribute convertToAttribute(MLIRContext *ctx, 61 IntegerOverflowFlags flags) { 62 return IntegerOverflowFlagsAttr::get(ctx, flags); 63 } 64 65 static LogicalResult 66 convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr, 67 function_ref<InFlightDiagnostic()> emitError) { 68 auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr); 69 if (!flagsAttr) { 70 return emitError() << "expected 'overflowFlags' attribute to be an " 71 "IntegerOverflowFlagsAttr, but got " 72 << attr; 73 } 74 flags = flagsAttr.getValue(); 75 return success(); 76 } 77 } // namespace mlir 78 79 static ParseResult parseOverflowFlags(AsmParser &p, 80 IntegerOverflowFlags &flags) { 81 if (failed(p.parseOptionalKeyword("overflow"))) { 82 flags = IntegerOverflowFlags::none; 83 return success(); 84 } 85 if (p.parseLess()) 86 return failure(); 87 do { 88 StringRef kw; 89 SMLoc loc = p.getCurrentLocation(); 90 if (p.parseKeyword(&kw)) 91 return failure(); 92 std::optional<IntegerOverflowFlags> flag = 93 symbolizeIntegerOverflowFlags(kw); 94 if (!flag) 95 return p.emitError(loc, 96 "invalid overflow flag: expected nsw, nuw, or none"); 97 flags = flags | *flag; 98 } while (succeeded(p.parseOptionalComma())); 99 return p.parseGreater(); 100 } 101 102 static void printOverflowFlags(AsmPrinter &p, Operation *op, 103 IntegerOverflowFlags flags) { 104 if (flags == IntegerOverflowFlags::none) 105 return; 106 p << " overflow<"; 107 SmallVector<StringRef, 2> strs; 108 if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) 109 strs.push_back("nsw"); 110 if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) 111 strs.push_back("nuw"); 112 llvm::interleaveComma(strs, 
p); 113 p << ">"; 114 } 115 116 //===----------------------------------------------------------------------===// 117 // Attribute Helpers 118 //===----------------------------------------------------------------------===// 119 120 static constexpr const char kElemTypeAttrName[] = "elem_type"; 121 122 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 123 SmallVector<NamedAttribute, 8> filteredAttrs( 124 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 125 if (attr.getName() == "fastmathFlags") { 126 auto defAttr = 127 FastmathFlagsAttr::get(attr.getValue().getContext(), {}); 128 return defAttr != attr.getValue(); 129 } 130 return true; 131 })); 132 return filteredAttrs; 133 } 134 135 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 136 /// fully defined llvm.func. 137 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 138 Operation *op, 139 SymbolTableCollection &symbolTable) { 140 StringRef name = symbol.getValue(); 141 auto func = 142 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 143 if (!func) 144 return op->emitOpError("'") 145 << name << "' does not reference a valid LLVM function"; 146 if (func.isExternal()) 147 return op->emitOpError("'") << name << "' does not have a definition"; 148 return success(); 149 } 150 151 /// Returns a boolean type that has the same shape as `type`. It supports both 152 /// fixed size vectors as well as scalable vectors. 153 static Type getI1SameShape(Type type) { 154 Type i1Type = IntegerType::get(type.getContext(), 1); 155 if (LLVM::isCompatibleVectorType(type)) 156 return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type)); 157 return i1Type; 158 } 159 160 // Parses one of the keywords provided in the list `keywords` and returns the 161 // position of the parsed keyword in the list. If none of the keywords from the 162 // list is parsed, returns -1. 
163 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 164 ArrayRef<StringRef> keywords) { 165 for (const auto &en : llvm::enumerate(keywords)) { 166 if (succeeded(parser.parseOptionalKeyword(en.value()))) 167 return en.index(); 168 } 169 return -1; 170 } 171 172 namespace { 173 template <typename Ty> 174 struct EnumTraits {}; 175 176 #define REGISTER_ENUM_TYPE(Ty) \ 177 template <> \ 178 struct EnumTraits<Ty> { \ 179 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 180 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 181 } 182 183 REGISTER_ENUM_TYPE(Linkage); 184 REGISTER_ENUM_TYPE(UnnamedAddr); 185 REGISTER_ENUM_TYPE(CConv); 186 REGISTER_ENUM_TYPE(TailCallKind); 187 REGISTER_ENUM_TYPE(Visibility); 188 } // namespace 189 190 /// Parse an enum from the keyword, or default to the provided default value. 191 /// The return type is the enum type by default, unless overridden with the 192 /// second template argument. 193 template <typename EnumTy, typename RetTy = EnumTy> 194 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 195 OperationState &result, 196 EnumTy defaultValue) { 197 SmallVector<StringRef, 10> names; 198 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 199 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 200 201 int index = parseOptionalKeywordAlternative(parser, names); 202 if (index == -1) 203 return static_cast<RetTy>(defaultValue); 204 return static_cast<RetTy>(index); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // Operand bundle helpers. 
209 //===----------------------------------------------------------------------===// 210 211 static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, 212 TypeRange operandTypes, StringRef tag) { 213 p.printString(tag); 214 p << "("; 215 216 if (!operands.empty()) { 217 p.printOperands(operands); 218 p << " : "; 219 llvm::interleaveComma(operandTypes, p); 220 } 221 222 p << ")"; 223 } 224 225 static void printOpBundles(OpAsmPrinter &p, Operation *op, 226 OperandRangeRange opBundleOperands, 227 TypeRangeRange opBundleOperandTypes, 228 std::optional<ArrayAttr> opBundleTags) { 229 if (opBundleOperands.empty()) 230 return; 231 assert(opBundleTags && "expect operand bundle tags"); 232 233 p << "["; 234 llvm::interleaveComma( 235 llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p, 236 [&p](auto bundle) { 237 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue(); 238 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle), 239 bundleTag); 240 }); 241 p << "]"; 242 } 243 244 static ParseResult parseOneOpBundle( 245 OpAsmParser &p, 246 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands, 247 SmallVector<SmallVector<Type>> &opBundleOperandTypes, 248 SmallVector<Attribute> &opBundleTags) { 249 SMLoc currentParserLoc = p.getCurrentLocation(); 250 SmallVector<OpAsmParser::UnresolvedOperand> operands; 251 SmallVector<Type> types; 252 std::string tag; 253 254 if (p.parseString(&tag)) 255 return p.emitError(currentParserLoc, "expect operand bundle tag"); 256 257 if (p.parseLParen()) 258 return failure(); 259 260 if (p.parseOptionalRParen()) { 261 if (p.parseOperandList(operands) || p.parseColon() || 262 p.parseTypeList(types) || p.parseRParen()) 263 return failure(); 264 } 265 266 opBundleOperands.push_back(std::move(operands)); 267 opBundleOperandTypes.push_back(std::move(types)); 268 opBundleTags.push_back(StringAttr::get(p.getContext(), tag)); 269 270 return success(); 271 } 272 273 static 
std::optional<ParseResult> parseOpBundles( 274 OpAsmParser &p, 275 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands, 276 SmallVector<SmallVector<Type>> &opBundleOperandTypes, 277 ArrayAttr &opBundleTags) { 278 if (p.parseOptionalLSquare()) 279 return std::nullopt; 280 281 if (succeeded(p.parseOptionalRSquare())) 282 return success(); 283 284 SmallVector<Attribute> opBundleTagAttrs; 285 auto bundleParser = [&] { 286 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes, 287 opBundleTagAttrs); 288 }; 289 if (p.parseCommaSeparatedList(bundleParser)) 290 return failure(); 291 292 if (p.parseRSquare()) 293 return failure(); 294 295 opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs); 296 297 return success(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // Printing, parsing, folding and builder for LLVM::CmpOp. 302 //===----------------------------------------------------------------------===// 303 304 void ICmpOp::print(OpAsmPrinter &p) { 305 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 306 << ", " << getOperand(1); 307 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 308 p << " : " << getLhs().getType(); 309 } 310 311 void FCmpOp::print(OpAsmPrinter &p) { 312 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 313 << ", " << getOperand(1); 314 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 315 p << " : " << getLhs().getType(); 316 } 317 318 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 319 // attribute-dict? `:` type 320 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 321 // attribute-dict? 
`:` type 322 template <typename CmpPredicateType> 323 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 324 StringAttr predicateAttr; 325 OpAsmParser::UnresolvedOperand lhs, rhs; 326 Type type; 327 SMLoc predicateLoc, trailingTypeLoc; 328 if (parser.getCurrentLocation(&predicateLoc) || 329 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 330 parser.parseOperand(lhs) || parser.parseComma() || 331 parser.parseOperand(rhs) || 332 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 333 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 334 parser.resolveOperand(lhs, type, result.operands) || 335 parser.resolveOperand(rhs, type, result.operands)) 336 return failure(); 337 338 // Replace the string attribute `predicate` with an integer attribute. 339 int64_t predicateValue = 0; 340 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 341 std::optional<ICmpPredicate> predicate = 342 symbolizeICmpPredicate(predicateAttr.getValue()); 343 if (!predicate) 344 return parser.emitError(predicateLoc) 345 << "'" << predicateAttr.getValue() 346 << "' is an incorrect value of the 'predicate' attribute"; 347 predicateValue = static_cast<int64_t>(*predicate); 348 } else { 349 std::optional<FCmpPredicate> predicate = 350 symbolizeFCmpPredicate(predicateAttr.getValue()); 351 if (!predicate) 352 return parser.emitError(predicateLoc) 353 << "'" << predicateAttr.getValue() 354 << "' is an incorrect value of the 'predicate' attribute"; 355 predicateValue = static_cast<int64_t>(*predicate); 356 } 357 358 result.attributes.set("predicate", 359 parser.getBuilder().getI64IntegerAttr(predicateValue)); 360 361 // The result type is either i1 or a vector type <? x i1> if the inputs are 362 // vectors. 
363 if (!isCompatibleType(type)) 364 return parser.emitError(trailingTypeLoc, 365 "expected LLVM dialect-compatible type"); 366 result.addTypes(getI1SameShape(type)); 367 return success(); 368 } 369 370 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 371 return parseCmpOp<ICmpPredicate>(parser, result); 372 } 373 374 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 375 return parseCmpOp<FCmpPredicate>(parser, result); 376 } 377 378 /// Returns a scalar or vector boolean attribute of the given type. 379 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 380 auto boolAttr = BoolAttr::get(ctx, value); 381 ShapedType shapedType = dyn_cast<ShapedType>(type); 382 if (!shapedType) 383 return boolAttr; 384 return DenseElementsAttr::get(shapedType, boolAttr); 385 } 386 387 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { 388 if (getPredicate() != ICmpPredicate::eq && 389 getPredicate() != ICmpPredicate::ne) 390 return {}; 391 392 // cmpi(eq/ne, x, x) -> true/false 393 if (getLhs() == getRhs()) 394 return getBoolAttribute(getType(), getContext(), 395 getPredicate() == ICmpPredicate::eq); 396 397 // cmpi(eq/ne, alloca, null) -> false/true 398 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>()) 399 return getBoolAttribute(getType(), getContext(), 400 getPredicate() == ICmpPredicate::ne); 401 402 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null) 403 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) { 404 Value lhs = getLhs(); 405 Value rhs = getRhs(); 406 getLhsMutable().assign(rhs); 407 getRhsMutable().assign(lhs); 408 return getResult(); 409 } 410 411 return {}; 412 } 413 414 //===----------------------------------------------------------------------===// 415 // Printing, parsing and verification for LLVM::AllocaOp. 
416 //===----------------------------------------------------------------------===// 417 418 void AllocaOp::print(OpAsmPrinter &p) { 419 auto funcTy = 420 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 421 422 if (getInalloca()) 423 p << " inalloca"; 424 425 p << ' ' << getArraySize() << " x " << getElemType(); 426 if (getAlignment() && *getAlignment() != 0) 427 p.printOptionalAttrDict((*this)->getAttrs(), 428 {kElemTypeAttrName, getInallocaAttrName()}); 429 else 430 p.printOptionalAttrDict( 431 (*this)->getAttrs(), 432 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()}); 433 p << " : " << funcTy; 434 } 435 436 // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type 437 // attribute-dict? `:` type `,` type 438 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 439 OpAsmParser::UnresolvedOperand arraySize; 440 Type type, elemType; 441 SMLoc trailingTypeLoc; 442 443 if (succeeded(parser.parseOptionalKeyword("inalloca"))) 444 result.addAttribute(getInallocaAttrName(result.name), 445 UnitAttr::get(parser.getContext())); 446 447 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 448 parser.parseType(elemType) || 449 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 450 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 451 return failure(); 452 453 std::optional<NamedAttribute> alignmentAttr = 454 result.attributes.getNamed("alignment"); 455 if (alignmentAttr.has_value()) { 456 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue()); 457 if (!alignmentInt) 458 return parser.emitError(parser.getNameLoc(), 459 "expected integer alignment"); 460 if (alignmentInt.getValue().isZero()) 461 result.attributes.erase("alignment"); 462 } 463 464 // Extract the result type from the trailing function type. 
465 auto funcType = llvm::dyn_cast<FunctionType>(type); 466 if (!funcType || funcType.getNumInputs() != 1 || 467 funcType.getNumResults() != 1) 468 return parser.emitError( 469 trailingTypeLoc, 470 "expected trailing function type with one argument and one result"); 471 472 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 473 return failure(); 474 475 Type resultType = funcType.getResult(0); 476 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType)) 477 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 478 479 result.addTypes({funcType.getResult(0)}); 480 return success(); 481 } 482 483 LogicalResult AllocaOp::verify() { 484 // Only certain target extension types can be used in 'alloca'. 485 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType()); 486 targetExtType && !targetExtType.supportsMemOps()) 487 return emitOpError() 488 << "this target extension type cannot be used in alloca"; 489 490 return success(); 491 } 492 493 //===----------------------------------------------------------------------===// 494 // LLVM::BrOp 495 //===----------------------------------------------------------------------===// 496 497 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 498 assert(index == 0 && "invalid successor index"); 499 return SuccessorOperands(getDestOperandsMutable()); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // LLVM::CondBrOp 504 //===----------------------------------------------------------------------===// 505 506 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 507 assert(index < getNumSuccessors() && "invalid successor index"); 508 return SuccessorOperands(index == 0 ? 
getTrueDestOperandsMutable() 509 : getFalseDestOperandsMutable()); 510 } 511 512 void CondBrOp::build(OpBuilder &builder, OperationState &result, 513 Value condition, Block *trueDest, ValueRange trueOperands, 514 Block *falseDest, ValueRange falseOperands, 515 std::optional<std::pair<uint32_t, uint32_t>> weights) { 516 DenseI32ArrayAttr weightsAttr; 517 if (weights) 518 weightsAttr = 519 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first), 520 static_cast<int32_t>(weights->second)}); 521 522 build(builder, result, condition, trueOperands, falseOperands, weightsAttr, 523 /*loop_annotation=*/{}, trueDest, falseDest); 524 } 525 526 //===----------------------------------------------------------------------===// 527 // LLVM::SwitchOp 528 //===----------------------------------------------------------------------===// 529 530 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 531 Block *defaultDestination, ValueRange defaultOperands, 532 DenseIntElementsAttr caseValues, 533 BlockRange caseDestinations, 534 ArrayRef<ValueRange> caseOperands, 535 ArrayRef<int32_t> branchWeights) { 536 DenseI32ArrayAttr weightsAttr; 537 if (!branchWeights.empty()) 538 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); 539 540 build(builder, result, value, defaultOperands, caseOperands, caseValues, 541 weightsAttr, defaultDestination, caseDestinations); 542 } 543 544 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 545 Block *defaultDestination, ValueRange defaultOperands, 546 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 547 ArrayRef<ValueRange> caseOperands, 548 ArrayRef<int32_t> branchWeights) { 549 DenseIntElementsAttr caseValuesAttr; 550 if (!caseValues.empty()) { 551 ShapedType caseValueType = VectorType::get( 552 static_cast<int64_t>(caseValues.size()), value.getType()); 553 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 554 } 555 556 build(builder, result, value, 
defaultDestination, defaultOperands, 557 caseValuesAttr, caseDestinations, caseOperands, branchWeights); 558 } 559 560 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 561 Block *defaultDestination, ValueRange defaultOperands, 562 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 563 ArrayRef<ValueRange> caseOperands, 564 ArrayRef<int32_t> branchWeights) { 565 DenseIntElementsAttr caseValuesAttr; 566 if (!caseValues.empty()) { 567 ShapedType caseValueType = VectorType::get( 568 static_cast<int64_t>(caseValues.size()), value.getType()); 569 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 570 } 571 572 build(builder, result, value, defaultDestination, defaultOperands, 573 caseValuesAttr, caseDestinations, caseOperands, branchWeights); 574 } 575 576 /// <cases> ::= `[` (case (`,` case )* )? `]` 577 /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 578 static ParseResult parseSwitchOpCases( 579 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, 580 SmallVectorImpl<Block *> &caseDestinations, 581 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 582 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 583 if (failed(parser.parseLSquare())) 584 return failure(); 585 if (succeeded(parser.parseOptionalRSquare())) 586 return success(); 587 SmallVector<APInt> values; 588 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 589 auto parseCase = [&]() { 590 int64_t value = 0; 591 if (failed(parser.parseInteger(value))) 592 return failure(); 593 values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); 594 595 Block *destination; 596 SmallVector<OpAsmParser::UnresolvedOperand> operands; 597 SmallVector<Type> operandTypes; 598 if (parser.parseColon() || parser.parseSuccessor(destination)) 599 return failure(); 600 if (!parser.parseOptionalLParen()) { 601 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 602 
/*allowResultNumber=*/false) || 603 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 604 return failure(); 605 } 606 caseDestinations.push_back(destination); 607 caseOperands.emplace_back(operands); 608 caseOperandTypes.emplace_back(operandTypes); 609 return success(); 610 }; 611 if (failed(parser.parseCommaSeparatedList(parseCase))) 612 return failure(); 613 614 ShapedType caseValueType = 615 VectorType::get(static_cast<int64_t>(values.size()), flagType); 616 caseValues = DenseIntElementsAttr::get(caseValueType, values); 617 return parser.parseRSquare(); 618 } 619 620 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 621 DenseIntElementsAttr caseValues, 622 SuccessorRange caseDestinations, 623 OperandRangeRange caseOperands, 624 const TypeRangeRange &caseOperandTypes) { 625 p << '['; 626 p.printNewline(); 627 if (!caseValues) { 628 p << ']'; 629 return; 630 } 631 632 size_t index = 0; 633 llvm::interleave( 634 llvm::zip(caseValues, caseDestinations), 635 [&](auto i) { 636 p << " "; 637 p << std::get<0>(i).getLimitedValue(); 638 p << ": "; 639 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 640 }, 641 [&] { 642 p << ','; 643 p.printNewline(); 644 }); 645 p.printNewline(); 646 p << ']'; 647 } 648 649 LogicalResult SwitchOp::verify() { 650 if ((!getCaseValues() && !getCaseDestinations().empty()) || 651 (getCaseValues() && 652 getCaseValues()->size() != 653 static_cast<int64_t>(getCaseDestinations().size()))) 654 return emitOpError("expects number of case values to match number of " 655 "case destinations"); 656 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 657 return emitError("expects number of branch weights to match number of " 658 "successors: ") 659 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 660 if (getCaseValues() && 661 getValue().getType() != getCaseValues()->getElementType()) 662 return emitError("expects case value type to match condition value 
type"); 663 return success(); 664 } 665 666 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 667 assert(index < getNumSuccessors() && "invalid successor index"); 668 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 669 : getCaseOperandsMutable(index - 1)); 670 } 671 672 //===----------------------------------------------------------------------===// 673 // Code for LLVM::GEPOp. 674 //===----------------------------------------------------------------------===// 675 676 constexpr int32_t GEPOp::kDynamicIndex; 677 678 GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { 679 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(), 680 getDynamicIndices()); 681 } 682 683 /// Returns the elemental type of any LLVM-compatible vector type or self. 684 static Type extractVectorElementType(Type type) { 685 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) 686 return vectorType.getElementType(); 687 if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type)) 688 return scalableVectorType.getElementType(); 689 if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type)) 690 return fixedVectorType.getElementType(); 691 return type; 692 } 693 694 /// Destructures the 'indices' parameter into 'rawConstantIndices' and 695 /// 'dynamicIndices', encoding the former in the process. In the process, 696 /// dynamic indices which are used to index into a structure type are converted 697 /// to constant indices when possible. To do this, the GEPs element type should 698 /// be passed as first parameter. 699 static void destructureIndices(Type currType, ArrayRef<GEPArg> indices, 700 SmallVectorImpl<int32_t> &rawConstantIndices, 701 SmallVectorImpl<Value> &dynamicIndices) { 702 for (const GEPArg &iter : indices) { 703 // If the thing we are currently indexing into is a struct we must turn 704 // any integer constants into constant indices. If this is not possible 705 // we don't do anything here. 
The verifier will catch it and emit a proper 706 // error. All other canonicalization is done in the fold method. 707 bool requiresConst = !rawConstantIndices.empty() && 708 isa_and_nonnull<LLVMStructType>(currType); 709 if (Value val = llvm::dyn_cast_if_present<Value>(iter)) { 710 APInt intC; 711 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) && 712 intC.isSignedIntN(kGEPConstantBitWidth)) { 713 rawConstantIndices.push_back(intC.getSExtValue()); 714 } else { 715 rawConstantIndices.push_back(GEPOp::kDynamicIndex); 716 dynamicIndices.push_back(val); 717 } 718 } else { 719 rawConstantIndices.push_back(cast<GEPConstantIndex>(iter)); 720 } 721 722 // Skip for very first iteration of this loop. First index does not index 723 // within the aggregates, but is just a pointer offset. 724 if (rawConstantIndices.size() == 1 || !currType) 725 continue; 726 727 currType = 728 TypeSwitch<Type, Type>(currType) 729 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 730 LLVMArrayType>([](auto containerType) { 731 return containerType.getElementType(); 732 }) 733 .Case([&](LLVMStructType structType) -> Type { 734 int64_t memberIndex = rawConstantIndices.back(); 735 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) < 736 structType.getBody().size()) 737 return structType.getBody()[memberIndex]; 738 return nullptr; 739 }) 740 .Default(Type(nullptr)); 741 } 742 } 743 744 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 745 Type elementType, Value basePtr, ArrayRef<GEPArg> indices, 746 bool inbounds, ArrayRef<NamedAttribute> attributes) { 747 SmallVector<int32_t> rawConstantIndices; 748 SmallVector<Value> dynamicIndices; 749 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices); 750 751 result.addTypes(resultType); 752 result.addAttributes(attributes); 753 result.addAttribute(getRawConstantIndicesAttrName(result.name), 754 builder.getDenseI32ArrayAttr(rawConstantIndices)); 755 if (inbounds) { 756 
result.addAttribute(getInboundsAttrName(result.name), 757 builder.getUnitAttr()); 758 } 759 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 760 result.addOperands(basePtr); 761 result.addOperands(dynamicIndices); 762 } 763 764 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 765 Type elementType, Value basePtr, ValueRange indices, 766 bool inbounds, ArrayRef<NamedAttribute> attributes) { 767 build(builder, result, resultType, elementType, basePtr, 768 SmallVector<GEPArg>(indices), inbounds, attributes); 769 } 770 771 static ParseResult 772 parseGEPIndices(OpAsmParser &parser, 773 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 774 DenseI32ArrayAttr &rawConstantIndices) { 775 SmallVector<int32_t> constantIndices; 776 777 auto idxParser = [&]() -> ParseResult { 778 int32_t constantIndex; 779 OptionalParseResult parsedInteger = 780 parser.parseOptionalInteger(constantIndex); 781 if (parsedInteger.has_value()) { 782 if (failed(parsedInteger.value())) 783 return failure(); 784 constantIndices.push_back(constantIndex); 785 return success(); 786 } 787 788 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 789 return parser.parseOperand(indices.emplace_back()); 790 }; 791 if (parser.parseCommaSeparatedList(idxParser)) 792 return failure(); 793 794 rawConstantIndices = 795 DenseI32ArrayAttr::get(parser.getContext(), constantIndices); 796 return success(); 797 } 798 799 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 800 OperandRange indices, 801 DenseI32ArrayAttr rawConstantIndices) { 802 llvm::interleaveComma( 803 GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer, 804 [&](PointerUnion<IntegerAttr, Value> cst) { 805 if (Value val = llvm::dyn_cast_if_present<Value>(cst)) 806 printer.printOperand(val); 807 else 808 printer << cast<IntegerAttr>(cst).getInt(); 809 }); 810 } 811 812 /// For the given `indices`, check if they comply with `baseGEPType`, 813 /// especially 
check against LLVMStructTypes nested within. 814 static LogicalResult 815 verifyStructIndices(Type baseGEPType, unsigned indexPos, 816 GEPIndicesAdaptor<ValueRange> indices, 817 function_ref<InFlightDiagnostic()> emitOpError) { 818 if (indexPos >= indices.size()) 819 // Stop searching 820 return success(); 821 822 return TypeSwitch<Type, LogicalResult>(baseGEPType) 823 .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult { 824 auto attr = dyn_cast<IntegerAttr>(indices[indexPos]); 825 if (!attr) 826 return emitOpError() << "expected index " << indexPos 827 << " indexing a struct to be constant"; 828 829 int32_t gepIndex = attr.getInt(); 830 ArrayRef<Type> elementTypes = structType.getBody(); 831 if (gepIndex < 0 || 832 static_cast<size_t>(gepIndex) >= elementTypes.size()) 833 return emitOpError() << "index " << indexPos 834 << " indexing a struct is out of bounds"; 835 836 // Instead of recursively going into every children types, we only 837 // dive into the one indexed by gepIndex. 838 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, 839 indices, emitOpError); 840 }) 841 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 842 LLVMArrayType>([&](auto containerType) -> LogicalResult { 843 return verifyStructIndices(containerType.getElementType(), indexPos + 1, 844 indices, emitOpError); 845 }) 846 .Default([&](auto otherType) -> LogicalResult { 847 return emitOpError() 848 << "type " << otherType << " cannot be indexed (index #" 849 << indexPos << ")"; 850 }); 851 } 852 853 /// Driver function around `verifyStructIndices`. 
854 static LogicalResult 855 verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices, 856 function_ref<InFlightDiagnostic()> emitOpError) { 857 return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError); 858 } 859 860 LogicalResult LLVM::GEPOp::verify() { 861 if (static_cast<size_t>( 862 llvm::count(getRawConstantIndices(), kDynamicIndex)) != 863 getDynamicIndices().size()) 864 return emitOpError("expected as many dynamic indices as specified in '") 865 << getRawConstantIndicesAttrName().getValue() << "'"; 866 867 return verifyStructIndices(getElemType(), getIndices(), 868 [&] { return emitOpError(); }); 869 } 870 871 //===----------------------------------------------------------------------===// 872 // LoadOp 873 //===----------------------------------------------------------------------===// 874 875 void LoadOp::getEffects( 876 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 877 &effects) { 878 effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable()); 879 // Volatile operations can have target-specific read-write effects on 880 // memory besides the one referred to by the pointer operand. 881 // Similarly, atomic operations that are monotonic or stricter cause 882 // synchronization that from a language point-of-view, are arbitrary 883 // read-writes into memory. 884 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && 885 getOrdering() != AtomicOrdering::unordered)) { 886 effects.emplace_back(MemoryEffects::Write::get()); 887 effects.emplace_back(MemoryEffects::Read::get()); 888 } 889 } 890 891 /// Returns true if the given type is supported by atomic operations. All 892 /// integer, float, and pointer types with a power-of-two bitsize and a minimal 893 /// size of 8 bits are supported. 
894 static bool isTypeCompatibleWithAtomicOp(Type type, 895 const DataLayout &dataLayout) { 896 if (!isa<IntegerType, LLVMPointerType>(type)) 897 if (!isCompatibleFloatingPointType(type)) 898 return false; 899 900 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type); 901 if (bitWidth.isScalable()) 902 return false; 903 // Needs to be at least 8 bits and a power of two. 904 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0; 905 } 906 907 /// Verifies the attributes and the type of atomic memory access operations. 908 template <typename OpTy> 909 LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, 910 ArrayRef<AtomicOrdering> unsupportedOrderings) { 911 if (memOp.getOrdering() != AtomicOrdering::not_atomic) { 912 DataLayout dataLayout = DataLayout::closest(memOp); 913 if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout)) 914 return memOp.emitOpError("unsupported type ") 915 << valueType << " for atomic access"; 916 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering())) 917 return memOp.emitOpError("unsupported ordering '") 918 << stringifyAtomicOrdering(memOp.getOrdering()) << "'"; 919 if (!memOp.getAlignment()) 920 return memOp.emitOpError("expected alignment for atomic access"); 921 return success(); 922 } 923 if (memOp.getSyncscope()) 924 return memOp.emitOpError( 925 "expected syncscope to be null for non-atomic access"); 926 return success(); 927 } 928 929 LogicalResult LoadOp::verify() { 930 Type valueType = getResult().getType(); 931 return verifyAtomicMemOp(*this, valueType, 932 {AtomicOrdering::release, AtomicOrdering::acq_rel}); 933 } 934 935 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type, 936 Value addr, unsigned alignment, bool isVolatile, 937 bool isNonTemporal, bool isInvariant, bool isInvariantGroup, 938 AtomicOrdering ordering, StringRef syncscope) { 939 build(builder, state, type, addr, 940 alignment ? 
builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, 941 isNonTemporal, isInvariant, isInvariantGroup, ordering, 942 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope), 943 /*access_groups=*/nullptr, 944 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, 945 /*tbaa=*/nullptr); 946 } 947 948 //===----------------------------------------------------------------------===// 949 // StoreOp 950 //===----------------------------------------------------------------------===// 951 952 void StoreOp::getEffects( 953 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 954 &effects) { 955 effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable()); 956 // Volatile operations can have target-specific read-write effects on 957 // memory besides the one referred to by the pointer operand. 958 // Similarly, atomic operations that are monotonic or stricter cause 959 // synchronization that from a language point-of-view, are arbitrary 960 // read-writes into memory. 961 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && 962 getOrdering() != AtomicOrdering::unordered)) { 963 effects.emplace_back(MemoryEffects::Write::get()); 964 effects.emplace_back(MemoryEffects::Read::get()); 965 } 966 } 967 968 LogicalResult StoreOp::verify() { 969 Type valueType = getValue().getType(); 970 return verifyAtomicMemOp(*this, valueType, 971 {AtomicOrdering::acquire, AtomicOrdering::acq_rel}); 972 } 973 974 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value, 975 Value addr, unsigned alignment, bool isVolatile, 976 bool isNonTemporal, bool isInvariantGroup, 977 AtomicOrdering ordering, StringRef syncscope) { 978 build(builder, state, value, addr, 979 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, 980 isNonTemporal, isInvariantGroup, ordering, 981 syncscope.empty() ? 
nullptr : builder.getStringAttr(syncscope), 982 /*access_groups=*/nullptr, 983 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 984 } 985 986 //===----------------------------------------------------------------------===// 987 // CallOp 988 //===----------------------------------------------------------------------===// 989 990 /// Gets the MLIR Op-like result types of a LLVMFunctionType. 991 static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) { 992 SmallVector<Type, 1> results; 993 Type resultType = calleeType.getReturnType(); 994 if (!isa<LLVM::LLVMVoidType>(resultType)) 995 results.push_back(resultType); 996 return results; 997 } 998 999 /// Gets the variadic callee type for a LLVMFunctionType. 1000 static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) { 1001 return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr; 1002 } 1003 1004 /// Constructs a LLVMFunctionType from MLIR `results` and `args`. 1005 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, 1006 ValueRange args) { 1007 Type resultType; 1008 if (results.empty()) 1009 resultType = LLVMVoidType::get(context); 1010 else 1011 resultType = results.front(); 1012 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()), 1013 /*isVarArg=*/false); 1014 } 1015 1016 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, 1017 StringRef callee, ValueRange args) { 1018 build(builder, state, results, builder.getStringAttr(callee), args); 1019 } 1020 1021 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, 1022 StringAttr callee, ValueRange args) { 1023 build(builder, state, results, SymbolRefAttr::get(callee), args); 1024 } 1025 1026 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, 1027 FlatSymbolRefAttr callee, ValueRange args) { 1028 assert(callee && "expected non-null callee in direct call builder"); 1029 
build(builder, state, results, 1030 /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr, 1031 /*branch_weights=*/nullptr, 1032 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, 1033 /*memory_effects=*/nullptr, 1034 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, 1035 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, 1036 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, 1037 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 1038 } 1039 1040 void CallOp::build(OpBuilder &builder, OperationState &state, 1041 LLVMFunctionType calleeType, StringRef callee, 1042 ValueRange args) { 1043 build(builder, state, calleeType, builder.getStringAttr(callee), args); 1044 } 1045 1046 void CallOp::build(OpBuilder &builder, OperationState &state, 1047 LLVMFunctionType calleeType, StringAttr callee, 1048 ValueRange args) { 1049 build(builder, state, calleeType, SymbolRefAttr::get(callee), args); 1050 } 1051 1052 void CallOp::build(OpBuilder &builder, OperationState &state, 1053 LLVMFunctionType calleeType, FlatSymbolRefAttr callee, 1054 ValueRange args) { 1055 build(builder, state, getCallOpResultTypes(calleeType), 1056 getCallOpVarCalleeType(calleeType), callee, args, 1057 /*fastmathFlags=*/nullptr, 1058 /*branch_weights=*/nullptr, /*CConv=*/nullptr, 1059 /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, 1060 /*convergent=*/nullptr, 1061 /*no_unwind=*/nullptr, /*will_return=*/nullptr, 1062 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, 1063 /*access_groups=*/nullptr, 1064 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 1065 } 1066 1067 void CallOp::build(OpBuilder &builder, OperationState &state, 1068 LLVMFunctionType calleeType, ValueRange args) { 1069 build(builder, state, getCallOpResultTypes(calleeType), 1070 getCallOpVarCalleeType(calleeType), 1071 /*callee=*/nullptr, args, 1072 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, 1073 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, 
1074 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, 1075 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, 1076 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, 1077 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 1078 } 1079 1080 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, 1081 ValueRange args) { 1082 auto calleeType = func.getFunctionType(); 1083 build(builder, state, getCallOpResultTypes(calleeType), 1084 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args, 1085 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, 1086 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, 1087 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, 1088 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, 1089 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, 1090 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 1091 } 1092 1093 CallInterfaceCallable CallOp::getCallableForCallee() { 1094 // Direct call. 1095 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 1096 return calleeAttr; 1097 // Indirect call, callee Value is the first operand. 1098 return getOperand(0); 1099 } 1100 1101 void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) { 1102 // Direct call. 1103 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { 1104 auto symRef = cast<SymbolRefAttr>(callee); 1105 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); 1106 } 1107 // Indirect call, callee Value is the first operand. 1108 return setOperand(0, cast<Value>(callee)); 1109 } 1110 1111 Operation::operand_range CallOp::getArgOperands() { 1112 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); 1113 } 1114 1115 MutableOperandRange CallOp::getArgOperandsMutable() { 1116 return MutableOperandRange(*this, getCallee().has_value() ? 
0 : 1, 1117 getCalleeOperands().size()); 1118 } 1119 1120 /// Verify that an inlinable callsite of a debug-info-bearing function in a 1121 /// debug-info-bearing function has a debug location attached to it. This 1122 /// mirrors an LLVM IR verifier. 1123 static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) { 1124 if (callee.isExternal()) 1125 return success(); 1126 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>(); 1127 if (!parentFunc) 1128 return success(); 1129 1130 auto hasSubprogram = [](Operation *op) { 1131 return op->getLoc() 1132 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() != 1133 nullptr; 1134 }; 1135 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee)) 1136 return success(); 1137 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc()); 1138 if (!containsLoc) 1139 return callOp.emitError() 1140 << "inlinable function call in a function with a DISubprogram " 1141 "location must have a debug location"; 1142 return success(); 1143 } 1144 1145 /// Verify that the parameter and return types of the variadic callee type match 1146 /// the `callOp` argument and result types. 1147 template <typename OpTy> 1148 LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { 1149 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType(); 1150 if (!varCalleeType) 1151 return success(); 1152 1153 // Verify the variadic callee type is a variadic function type. 1154 if (!varCalleeType->isVarArg()) 1155 return callOp.emitOpError( 1156 "expected var_callee_type to be a variadic function type"); 1157 1158 // Verify the variadic callee type has at most as many parameters as the call 1159 // has argument operands. 1160 if (varCalleeType->getNumParams() > callOp.getArgOperands().size()) 1161 return callOp.emitOpError("expected var_callee_type to have at most ") 1162 << callOp.getArgOperands().size() << " parameters"; 1163 1164 // Verify the variadic callee type matches the call argument types. 
1165 for (auto [paramType, operand] : 1166 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands())) 1167 if (paramType != operand.getType()) 1168 return callOp.emitOpError() 1169 << "var_callee_type parameter type mismatch: " << paramType 1170 << " != " << operand.getType(); 1171 1172 // Verify the variadic callee type matches the call result type. 1173 if (!callOp.getNumResults()) { 1174 if (!isa<LLVMVoidType>(varCalleeType->getReturnType())) 1175 return callOp.emitOpError("expected var_callee_type to return void"); 1176 } else { 1177 if (callOp.getResult().getType() != varCalleeType->getReturnType()) 1178 return callOp.emitOpError("var_callee_type return type mismatch: ") 1179 << varCalleeType->getReturnType() 1180 << " != " << callOp.getResult().getType(); 1181 } 1182 return success(); 1183 } 1184 1185 template <typename OpType> 1186 static LogicalResult verifyOperandBundles(OpType &op) { 1187 OperandRangeRange opBundleOperands = op.getOpBundleOperands(); 1188 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags(); 1189 1190 auto isStringAttr = [](Attribute tagAttr) { 1191 return isa<StringAttr>(tagAttr); 1192 }; 1193 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr)) 1194 return op.emitError("operand bundle tag must be a StringAttr"); 1195 1196 size_t numOpBundles = opBundleOperands.size(); 1197 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0; 1198 if (numOpBundles != numOpBundleTags) 1199 return op.emitError("expected ") 1200 << numOpBundles << " operand bundle tags, but actually got " 1201 << numOpBundleTags; 1202 1203 return success(); 1204 } 1205 1206 LogicalResult CallOp::verify() { return verifyOperandBundles(*this); } 1207 1208 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1209 if (failed(verifyCallOpVarCalleeType(*this))) 1210 return failure(); 1211 1212 // Type for the callee, we'll get it differently depending if it is a direct 1213 // or indirect call. 
1214 Type fnType; 1215 1216 bool isIndirect = false; 1217 1218 // If this is an indirect call, the callee attribute is missing. 1219 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1220 if (!calleeName) { 1221 isIndirect = true; 1222 if (!getNumOperands()) 1223 return emitOpError( 1224 "must have either a `callee` attribute or at least an operand"); 1225 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType()); 1226 if (!ptrType) 1227 return emitOpError("indirect call expects a pointer as callee: ") 1228 << getOperand(0).getType(); 1229 1230 return success(); 1231 } else { 1232 Operation *callee = 1233 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1234 if (!callee) 1235 return emitOpError() 1236 << "'" << calleeName.getValue() 1237 << "' does not reference a symbol in the current scope"; 1238 auto fn = dyn_cast<LLVMFuncOp>(callee); 1239 if (!fn) 1240 return emitOpError() << "'" << calleeName.getValue() 1241 << "' does not reference a valid LLVM function"; 1242 1243 if (failed(verifyCallOpDebugInfo(*this, fn))) 1244 return failure(); 1245 fnType = fn.getFunctionType(); 1246 } 1247 1248 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType); 1249 if (!funcType) 1250 return emitOpError("callee does not have a functional type: ") << fnType; 1251 1252 if (funcType.isVarArg() && !getVarCalleeType()) 1253 return emitOpError() << "missing var_callee_type attribute for vararg call"; 1254 1255 // Verify that the operand and result types match the callee. 
1256 1257 if (!funcType.isVarArg() && 1258 funcType.getNumParams() != (getCalleeOperands().size() - isIndirect)) 1259 return emitOpError() << "incorrect number of operands (" 1260 << (getCalleeOperands().size() - isIndirect) 1261 << ") for callee (expecting: " 1262 << funcType.getNumParams() << ")"; 1263 1264 if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect)) 1265 return emitOpError() << "incorrect number of operands (" 1266 << (getCalleeOperands().size() - isIndirect) 1267 << ") for varargs callee (expecting at least: " 1268 << funcType.getNumParams() << ")"; 1269 1270 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1271 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1272 return emitOpError() << "operand type mismatch for operand " << i << ": " 1273 << getOperand(i + isIndirect).getType() 1274 << " != " << funcType.getParamType(i); 1275 1276 if (getNumResults() == 0 && 1277 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) 1278 return emitOpError() << "expected function call to produce a value"; 1279 1280 if (getNumResults() != 0 && 1281 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) 1282 return emitOpError() 1283 << "calling function with void result must not produce values"; 1284 1285 if (getNumResults() > 1) 1286 return emitOpError() 1287 << "expected LLVM function call to produce 0 or 1 result"; 1288 1289 if (getNumResults() && getResult().getType() != funcType.getReturnType()) 1290 return emitOpError() << "result type mismatch: " << getResult().getType() 1291 << " != " << funcType.getReturnType(); 1292 1293 return success(); 1294 } 1295 1296 void CallOp::print(OpAsmPrinter &p) { 1297 auto callee = getCallee(); 1298 bool isDirect = callee.has_value(); 1299 1300 p << ' '; 1301 1302 // Print calling convention. 
1303 if (getCConv() != LLVM::CConv::C) 1304 p << stringifyCConv(getCConv()) << ' '; 1305 1306 if (getTailCallKind() != LLVM::TailCallKind::None) 1307 p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' '; 1308 1309 // Print the direct callee if present as a function attribute, or an indirect 1310 // callee (first operand) otherwise. 1311 if (isDirect) 1312 p.printSymbolName(callee.value()); 1313 else 1314 p << getOperand(0); 1315 1316 auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1); 1317 p << '(' << args << ')'; 1318 1319 // Print the variadic callee type if the call is variadic. 1320 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) 1321 p << " vararg(" << *varCalleeType << ")"; 1322 1323 if (!getOpBundleOperands().empty()) { 1324 p << " "; 1325 printOpBundles(p, *this, getOpBundleOperands(), 1326 getOpBundleOperands().getTypes(), getOpBundleTags()); 1327 } 1328 1329 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), 1330 {getCalleeAttrName(), getTailCallKindAttrName(), 1331 getVarCalleeTypeAttrName(), getCConvAttrName(), 1332 getOperandSegmentSizesAttrName(), 1333 getOpBundleSizesAttrName(), 1334 getOpBundleTagsAttrName()}); 1335 1336 p << " : "; 1337 if (!isDirect) 1338 p << getOperand(0).getType() << ", "; 1339 1340 // Reconstruct the function MLIR function type from operand and result types. 1341 p.printFunctionalType(args.getTypes(), getResultTypes()); 1342 } 1343 1344 /// Parses the type of a call operation and resolves the operands if the parsing 1345 /// succeeds. Returns failure otherwise. 
1346 static ParseResult parseCallTypeAndResolveOperands( 1347 OpAsmParser &parser, OperationState &result, bool isDirect, 1348 ArrayRef<OpAsmParser::UnresolvedOperand> operands) { 1349 SMLoc trailingTypesLoc = parser.getCurrentLocation(); 1350 SmallVector<Type> types; 1351 if (parser.parseColonTypeList(types)) 1352 return failure(); 1353 1354 if (isDirect && types.size() != 1) 1355 return parser.emitError(trailingTypesLoc, 1356 "expected direct call to have 1 trailing type"); 1357 if (!isDirect && types.size() != 2) 1358 return parser.emitError(trailingTypesLoc, 1359 "expected indirect call to have 2 trailing types"); 1360 1361 auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val()); 1362 if (!funcType) 1363 return parser.emitError(trailingTypesLoc, 1364 "expected trailing function type"); 1365 if (funcType.getNumResults() > 1) 1366 return parser.emitError(trailingTypesLoc, 1367 "expected function with 0 or 1 result"); 1368 if (funcType.getNumResults() == 1 && 1369 llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0))) 1370 return parser.emitError(trailingTypesLoc, 1371 "expected a non-void result type"); 1372 1373 // The head element of the types list matches the callee type for 1374 // indirect calls, while the types list is emtpy for direct calls. 1375 // Append the function input types to resolve the call operation 1376 // operands. 1377 llvm::append_range(types, funcType.getInputs()); 1378 if (parser.resolveOperands(operands, types, parser.getNameLoc(), 1379 result.operands)) 1380 return failure(); 1381 if (funcType.getNumResults() != 0) 1382 result.addTypes(funcType.getResults()); 1383 1384 return success(); 1385 } 1386 1387 /// Parses an optional function pointer operand before the call argument list 1388 /// for indirect calls, or stops parsing at the function identifier otherwise. 
1389 static ParseResult parseOptionalCallFuncPtr( 1390 OpAsmParser &parser, 1391 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) { 1392 OpAsmParser::UnresolvedOperand funcPtrOperand; 1393 OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand); 1394 if (parseResult.has_value()) { 1395 if (failed(*parseResult)) 1396 return *parseResult; 1397 operands.push_back(funcPtrOperand); 1398 } 1399 return success(); 1400 } 1401 1402 static ParseResult resolveOpBundleOperands( 1403 OpAsmParser &parser, SMLoc loc, OperationState &state, 1404 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands, 1405 ArrayRef<SmallVector<Type>> opBundleOperandTypes, 1406 StringAttr opBundleSizesAttrName) { 1407 unsigned opBundleIndex = 0; 1408 for (const auto &[operands, types] : 1409 llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) { 1410 if (operands.size() != types.size()) 1411 return parser.emitError(loc, "expected ") 1412 << operands.size() 1413 << " types for operand bundle operands for operand bundle #" 1414 << opBundleIndex << ", but actually got " << types.size(); 1415 if (parser.resolveOperands(operands, types, loc, state.operands)) 1416 return failure(); 1417 } 1418 1419 SmallVector<int32_t> opBundleSizes; 1420 opBundleSizes.reserve(opBundleOperands.size()); 1421 for (const auto &operands : opBundleOperands) 1422 opBundleSizes.push_back(operands.size()); 1423 1424 state.addAttribute( 1425 opBundleSizesAttrName, 1426 DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes)); 1427 1428 return success(); 1429 } 1430 1431 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use) 1432 // `(` ssa-use-list `)` 1433 // ( `vararg(` var-callee-type `)` )? 1434 // ( `[` op-bundles-list `]` )? 1435 // attribute-dict? `:` (type `,`)? 
function-type 1436 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1437 SymbolRefAttr funcAttr; 1438 TypeAttr varCalleeType; 1439 SmallVector<OpAsmParser::UnresolvedOperand> operands; 1440 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands; 1441 SmallVector<SmallVector<Type>> opBundleOperandTypes; 1442 ArrayAttr opBundleTags; 1443 1444 // Default to C Calling Convention if no keyword is provided. 1445 result.addAttribute( 1446 getCConvAttrName(result.name), 1447 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 1448 parser, result, LLVM::CConv::C))); 1449 1450 result.addAttribute( 1451 getTailCallKindAttrName(result.name), 1452 TailCallKindAttr::get(parser.getContext(), 1453 parseOptionalLLVMKeyword<TailCallKind>( 1454 parser, result, LLVM::TailCallKind::None))); 1455 1456 // Parse a function pointer for indirect calls. 1457 if (parseOptionalCallFuncPtr(parser, operands)) 1458 return failure(); 1459 bool isDirect = operands.empty(); 1460 1461 // Parse a function identifier for direct calls. 1462 if (isDirect) 1463 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1464 return failure(); 1465 1466 // Parse the function arguments. 
1467 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) 1468 return failure(); 1469 1470 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded(); 1471 if (isVarArg) { 1472 StringAttr varCalleeTypeAttrName = 1473 CallOp::getVarCalleeTypeAttrName(result.name); 1474 if (parser.parseLParen().failed() || 1475 parser 1476 .parseAttribute(varCalleeType, varCalleeTypeAttrName, 1477 result.attributes) 1478 .failed() || 1479 parser.parseRParen().failed()) 1480 return failure(); 1481 } 1482 1483 SMLoc opBundlesLoc = parser.getCurrentLocation(); 1484 if (std::optional<ParseResult> result = parseOpBundles( 1485 parser, opBundleOperands, opBundleOperandTypes, opBundleTags); 1486 result && failed(*result)) 1487 return failure(); 1488 if (opBundleTags && !opBundleTags.empty()) 1489 result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(), 1490 opBundleTags); 1491 1492 if (parser.parseOptionalAttrDict(result.attributes)) 1493 return failure(); 1494 1495 // Parse the trailing type list and resolve the operands. 
1496 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) 1497 return failure(); 1498 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, 1499 opBundleOperandTypes, 1500 getOpBundleSizesAttrName(result.name))) 1501 return failure(); 1502 1503 int32_t numOpBundleOperands = 0; 1504 for (const auto &operands : opBundleOperands) 1505 numOpBundleOperands += operands.size(); 1506 1507 result.addAttribute( 1508 CallOp::getOperandSegmentSizeAttr(), 1509 parser.getBuilder().getDenseI32ArrayAttr( 1510 {static_cast<int32_t>(operands.size()), numOpBundleOperands})); 1511 return success(); 1512 } 1513 1514 LLVMFunctionType CallOp::getCalleeFunctionType() { 1515 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) 1516 return *varCalleeType; 1517 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); 1518 } 1519 1520 ///===---------------------------------------------------------------------===// 1521 /// LLVM::InvokeOp 1522 ///===---------------------------------------------------------------------===// 1523 1524 void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, 1525 ValueRange ops, Block *normal, ValueRange normalOps, 1526 Block *unwind, ValueRange unwindOps) { 1527 auto calleeType = func.getFunctionType(); 1528 build(builder, state, getCallOpResultTypes(calleeType), 1529 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, 1530 normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); 1531 } 1532 1533 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, 1534 FlatSymbolRefAttr callee, ValueRange ops, Block *normal, 1535 ValueRange normalOps, Block *unwind, 1536 ValueRange unwindOps) { 1537 build(builder, state, tys, 1538 /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr, 1539 nullptr, {}, {}, normal, unwind); 1540 } 1541 1542 void InvokeOp::build(OpBuilder &builder, OperationState &state, 1543 
LLVMFunctionType calleeType, FlatSymbolRefAttr callee, 1544 ValueRange ops, Block *normal, ValueRange normalOps, 1545 Block *unwind, ValueRange unwindOps) { 1546 build(builder, state, getCallOpResultTypes(calleeType), 1547 getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps, 1548 nullptr, nullptr, {}, {}, normal, unwind); 1549 } 1550 1551 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 1552 assert(index < getNumSuccessors() && "invalid successor index"); 1553 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 1554 : getUnwindDestOperandsMutable()); 1555 } 1556 1557 CallInterfaceCallable InvokeOp::getCallableForCallee() { 1558 // Direct call. 1559 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 1560 return calleeAttr; 1561 // Indirect call, callee Value is the first operand. 1562 return getOperand(0); 1563 } 1564 1565 void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) { 1566 // Direct call. 1567 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { 1568 auto symRef = cast<SymbolRefAttr>(callee); 1569 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); 1570 } 1571 // Indirect call, callee Value is the first operand. 1572 return setOperand(0, cast<Value>(callee)); 1573 } 1574 1575 Operation::operand_range InvokeOp::getArgOperands() { 1576 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); 1577 } 1578 1579 MutableOperandRange InvokeOp::getArgOperandsMutable() { 1580 return MutableOperandRange(*this, getCallee().has_value() ? 
0 : 1, 1581 getCalleeOperands().size()); 1582 } 1583 1584 LogicalResult InvokeOp::verify() { 1585 if (failed(verifyCallOpVarCalleeType(*this))) 1586 return failure(); 1587 1588 Block *unwindDest = getUnwindDest(); 1589 if (unwindDest->empty()) 1590 return emitError("must have at least one operation in unwind destination"); 1591 1592 // In unwind destination, first operation must be LandingpadOp 1593 if (!isa<LandingpadOp>(unwindDest->front())) 1594 return emitError("first operation in unwind destination should be a " 1595 "llvm.landingpad operation"); 1596 1597 if (failed(verifyOperandBundles(*this))) 1598 return failure(); 1599 1600 return success(); 1601 } 1602 1603 void InvokeOp::print(OpAsmPrinter &p) { 1604 auto callee = getCallee(); 1605 bool isDirect = callee.has_value(); 1606 1607 p << ' '; 1608 1609 // Print calling convention. 1610 if (getCConv() != LLVM::CConv::C) 1611 p << stringifyCConv(getCConv()) << ' '; 1612 1613 // Either function name or pointer 1614 if (isDirect) 1615 p.printSymbolName(callee.value()); 1616 else 1617 p << getOperand(0); 1618 1619 p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')'; 1620 p << " to "; 1621 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 1622 p << " unwind "; 1623 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 1624 1625 // Print the variadic callee type if the invoke is variadic. 
1626 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) 1627 p << " vararg(" << *varCalleeType << ")"; 1628 1629 if (!getOpBundleOperands().empty()) { 1630 p << " "; 1631 printOpBundles(p, *this, getOpBundleOperands(), 1632 getOpBundleOperands().getTypes(), getOpBundleTags()); 1633 } 1634 1635 p.printOptionalAttrDict((*this)->getAttrs(), 1636 {getCalleeAttrName(), getOperandSegmentSizeAttr(), 1637 getCConvAttrName(), getVarCalleeTypeAttrName(), 1638 getOpBundleSizesAttrName(), 1639 getOpBundleTagsAttrName()}); 1640 1641 p << " : "; 1642 if (!isDirect) 1643 p << getOperand(0).getType() << ", "; 1644 p.printFunctionalType( 1645 llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1), 1646 getResultTypes()); 1647 } 1648 1649 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use) 1650 // `(` ssa-use-list `)` 1651 // `to` bb-id (`[` ssa-use-and-type-list `]`)? 1652 // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 1653 // ( `vararg(` var-callee-type `)` )? 1654 // ( `[` op-bundles-list `]` )? 1655 // attribute-dict? `:` (type `,`)? function-type 1656 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 1657 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1658 SymbolRefAttr funcAttr; 1659 TypeAttr varCalleeType; 1660 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands; 1661 SmallVector<SmallVector<Type>> opBundleOperandTypes; 1662 ArrayAttr opBundleTags; 1663 Block *normalDest, *unwindDest; 1664 SmallVector<Value, 4> normalOperands, unwindOperands; 1665 Builder &builder = parser.getBuilder(); 1666 1667 // Default to C Calling Convention if no keyword is provided. 1668 result.addAttribute( 1669 getCConvAttrName(result.name), 1670 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 1671 parser, result, LLVM::CConv::C))); 1672 1673 // Parse a function pointer for indirect calls. 
1674 if (parseOptionalCallFuncPtr(parser, operands)) 1675 return failure(); 1676 bool isDirect = operands.empty(); 1677 1678 // Parse a function identifier for direct calls. 1679 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 1680 return failure(); 1681 1682 // Parse the function arguments. 1683 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1684 parser.parseKeyword("to") || 1685 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 1686 parser.parseKeyword("unwind") || 1687 parser.parseSuccessorAndUseList(unwindDest, unwindOperands)) 1688 return failure(); 1689 1690 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded(); 1691 if (isVarArg) { 1692 StringAttr varCalleeTypeAttrName = 1693 InvokeOp::getVarCalleeTypeAttrName(result.name); 1694 if (parser.parseLParen().failed() || 1695 parser 1696 .parseAttribute(varCalleeType, varCalleeTypeAttrName, 1697 result.attributes) 1698 .failed() || 1699 parser.parseRParen().failed()) 1700 return failure(); 1701 } 1702 1703 SMLoc opBundlesLoc = parser.getCurrentLocation(); 1704 if (std::optional<ParseResult> result = parseOpBundles( 1705 parser, opBundleOperands, opBundleOperandTypes, opBundleTags); 1706 result && failed(*result)) 1707 return failure(); 1708 if (opBundleTags && !opBundleTags.empty()) 1709 result.addAttribute( 1710 InvokeOp::getOpBundleTagsAttrName(result.name).getValue(), 1711 opBundleTags); 1712 1713 if (parser.parseOptionalAttrDict(result.attributes)) 1714 return failure(); 1715 1716 // Parse the trailing type list and resolve the function operands. 
1717 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) 1718 return failure(); 1719 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, 1720 opBundleOperandTypes, 1721 getOpBundleSizesAttrName(result.name))) 1722 return failure(); 1723 1724 result.addSuccessors({normalDest, unwindDest}); 1725 result.addOperands(normalOperands); 1726 result.addOperands(unwindOperands); 1727 1728 int32_t numOpBundleOperands = 0; 1729 for (const auto &operands : opBundleOperands) 1730 numOpBundleOperands += operands.size(); 1731 1732 result.addAttribute( 1733 InvokeOp::getOperandSegmentSizeAttr(), 1734 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()), 1735 static_cast<int32_t>(normalOperands.size()), 1736 static_cast<int32_t>(unwindOperands.size()), 1737 numOpBundleOperands})); 1738 return success(); 1739 } 1740 1741 LLVMFunctionType InvokeOp::getCalleeFunctionType() { 1742 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) 1743 return *varCalleeType; 1744 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); 1745 } 1746 1747 ///===----------------------------------------------------------------------===// 1748 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1749 ///===----------------------------------------------------------------------===// 1750 1751 LogicalResult LandingpadOp::verify() { 1752 Value value; 1753 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1754 if (!func.getPersonality()) 1755 return emitError( 1756 "llvm.landingpad needs to be in a function with a personality"); 1757 } 1758 1759 // Consistency of llvm.landingpad result types is checked in 1760 // LLVMFuncOp::verify(). 
1761 1762 if (!getCleanup() && getOperands().empty()) 1763 return emitError("landingpad instruction expects at least one clause or " 1764 "cleanup attribute"); 1765 1766 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1767 value = getOperand(idx); 1768 bool isFilter = llvm::isa<LLVMArrayType>(value.getType()); 1769 if (isFilter) { 1770 // FIXME: Verify filter clauses when arrays are appropriately handled 1771 } else { 1772 // catch - global addresses only. 1773 // Bitcast ops should have global addresses as their args. 1774 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1775 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1776 continue; 1777 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1778 << "global addresses expected as operand to " 1779 "bitcast used in clauses for landingpad"; 1780 } 1781 // ZeroOp and AddressOfOp allowed 1782 if (value.getDefiningOp<ZeroOp>()) 1783 continue; 1784 if (value.getDefiningOp<AddressOfOp>()) 1785 continue; 1786 return emitError("clause #") 1787 << idx << " is not a known constant - null, addressof, bitcast"; 1788 } 1789 } 1790 return success(); 1791 } 1792 1793 void LandingpadOp::print(OpAsmPrinter &p) { 1794 p << (getCleanup() ? " cleanup " : " "); 1795 1796 // Clauses 1797 for (auto value : getOperands()) { 1798 // Similar to llvm - if clause is an array type then it is filter 1799 // clause else catch clause 1800 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType()); 1801 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1802 << value.getType() << ") "; 1803 } 1804 1805 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1806 1807 p << ": " << getType(); 1808 } 1809 1810 // <operation> ::= `llvm.landingpad` `cleanup`? 1811 // ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 
1812 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1813 // Check for cleanup 1814 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1815 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1816 1817 // Parse clauses with types 1818 while (succeeded(parser.parseOptionalLParen()) && 1819 (succeeded(parser.parseOptionalKeyword("filter")) || 1820 succeeded(parser.parseOptionalKeyword("catch")))) { 1821 OpAsmParser::UnresolvedOperand operand; 1822 Type ty; 1823 if (parser.parseOperand(operand) || parser.parseColon() || 1824 parser.parseType(ty) || 1825 parser.resolveOperand(operand, ty, result.operands) || 1826 parser.parseRParen()) 1827 return failure(); 1828 } 1829 1830 Type type; 1831 if (parser.parseColon() || parser.parseType(type)) 1832 return failure(); 1833 1834 result.addTypes(type); 1835 return success(); 1836 } 1837 1838 //===----------------------------------------------------------------------===// 1839 // ExtractValueOp 1840 //===----------------------------------------------------------------------===// 1841 1842 /// Extract the type at `position` in the LLVM IR aggregate type 1843 /// `containerType`. Each element of `position` is an index into a nested 1844 /// aggregate type. Return the resulting type or emit an error. 1845 static Type getInsertExtractValueElementType( 1846 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType, 1847 ArrayRef<int64_t> position) { 1848 Type llvmType = containerType; 1849 if (!isCompatibleType(containerType)) { 1850 emitError("expected LLVM IR Dialect type, got ") << containerType; 1851 return {}; 1852 } 1853 1854 // Infer the element type from the structure type: iteratively step inside the 1855 // type by taking the element type, indexed by the position attribute for 1856 // structures. Check the position index before accessing, it is supposed to 1857 // be in bounds. 
1858 for (int64_t idx : position) { 1859 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) { 1860 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) { 1861 emitError("position out of bounds: ") << idx; 1862 return {}; 1863 } 1864 llvmType = arrayType.getElementType(); 1865 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) { 1866 if (idx < 0 || 1867 static_cast<unsigned>(idx) >= structType.getBody().size()) { 1868 emitError("position out of bounds: ") << idx; 1869 return {}; 1870 } 1871 llvmType = structType.getBody()[idx]; 1872 } else { 1873 emitError("expected LLVM IR structure/array type, got: ") << llvmType; 1874 return {}; 1875 } 1876 } 1877 return llvmType; 1878 } 1879 1880 /// Extract the type at `position` in the wrapped LLVM IR aggregate type 1881 /// `containerType`. 1882 static Type getInsertExtractValueElementType(Type llvmType, 1883 ArrayRef<int64_t> position) { 1884 for (int64_t idx : position) { 1885 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) 1886 llvmType = structType.getBody()[idx]; 1887 else 1888 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType(); 1889 } 1890 return llvmType; 1891 } 1892 1893 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { 1894 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1895 OpFoldResult result = {}; 1896 while (insertValueOp) { 1897 if (getPosition() == insertValueOp.getPosition()) 1898 return insertValueOp.getValue(); 1899 unsigned min = 1900 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1901 // If one is fully prefix of the other, stop propagating back as it will 1902 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1903 // following example: 1904 // ``` 1905 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1906 // !llvm.array<4 x !llvm.array<4 x f32>> 1907 // %2 = llvm.insertvalue %arr, %1[0] : 1908 // !llvm.array<4 x !llvm.array<4 x f32>> 1909 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>> 1910 // ``` 1911 if (getPosition().take_front(min) == 1912 insertValueOp.getPosition().take_front(min)) 1913 return result; 1914 1915 // If neither a prefix, nor the exact position, we can extract out of the 1916 // value being inserted into. Moreover, we can try again if that operand 1917 // is itself an insertvalue expression. 1918 getContainerMutable().assign(insertValueOp.getContainer()); 1919 result = getResult(); 1920 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1921 } 1922 return result; 1923 } 1924 1925 LogicalResult ExtractValueOp::verify() { 1926 auto emitError = [this](StringRef msg) { return emitOpError(msg); }; 1927 Type valueType = getInsertExtractValueElementType( 1928 emitError, getContainer().getType(), getPosition()); 1929 if (!valueType) 1930 return failure(); 1931 1932 if (getRes().getType() != valueType) 1933 return emitOpError() << "Type mismatch: extracting from " 1934 << getContainer().getType() << " should produce " 1935 << valueType << " but this op returns " 1936 << getRes().getType(); 1937 return success(); 1938 } 1939 1940 void ExtractValueOp::build(OpBuilder &builder, OperationState &state, 1941 Value container, ArrayRef<int64_t> position) { 1942 build(builder, state, 1943 getInsertExtractValueElementType(container.getType(), position), 1944 container, builder.getAttr<DenseI64ArrayAttr>(position)); 1945 } 1946 1947 //===----------------------------------------------------------------------===// 1948 // InsertValueOp 1949 //===----------------------------------------------------------------------===// 1950 1951 /// Infer the value type from the container type and 
position. 1952 static ParseResult 1953 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, 1954 Type containerType, 1955 DenseI64ArrayAttr position) { 1956 valueType = getInsertExtractValueElementType( 1957 [&](StringRef msg) { 1958 return parser.emitError(parser.getCurrentLocation(), msg); 1959 }, 1960 containerType, position.asArrayRef()); 1961 return success(!!valueType); 1962 } 1963 1964 /// Nothing to print for an inferred type. 1965 static void printInsertExtractValueElementType(AsmPrinter &printer, 1966 Operation *op, Type valueType, 1967 Type containerType, 1968 DenseI64ArrayAttr position) {} 1969 1970 LogicalResult InsertValueOp::verify() { 1971 auto emitError = [this](StringRef msg) { return emitOpError(msg); }; 1972 Type valueType = getInsertExtractValueElementType( 1973 emitError, getContainer().getType(), getPosition()); 1974 if (!valueType) 1975 return failure(); 1976 1977 if (getValue().getType() != valueType) 1978 return emitOpError() << "Type mismatch: cannot insert " 1979 << getValue().getType() << " into " 1980 << getContainer().getType(); 1981 1982 return success(); 1983 } 1984 1985 //===----------------------------------------------------------------------===// 1986 // ReturnOp 1987 //===----------------------------------------------------------------------===// 1988 1989 LogicalResult ReturnOp::verify() { 1990 auto parent = (*this)->getParentOfType<LLVMFuncOp>(); 1991 if (!parent) 1992 return success(); 1993 1994 Type expectedType = parent.getFunctionType().getReturnType(); 1995 if (llvm::isa<LLVMVoidType>(expectedType)) { 1996 if (!getArg()) 1997 return success(); 1998 InFlightDiagnostic diag = emitOpError("expected no operands"); 1999 diag.attachNote(parent->getLoc()) << "when returning from function"; 2000 return diag; 2001 } 2002 if (!getArg()) { 2003 if (llvm::isa<LLVMVoidType>(expectedType)) 2004 return success(); 2005 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 2006 diag.attachNote(parent->getLoc()) 
<< "when returning from function"; 2007 return diag; 2008 } 2009 if (expectedType != getArg().getType()) { 2010 InFlightDiagnostic diag = emitOpError("mismatching result types"); 2011 diag.attachNote(parent->getLoc()) << "when returning from function"; 2012 return diag; 2013 } 2014 return success(); 2015 } 2016 2017 //===----------------------------------------------------------------------===// 2018 // LLVM::AddressOfOp. 2019 //===----------------------------------------------------------------------===// 2020 2021 static Operation *parentLLVMModule(Operation *op) { 2022 Operation *module = op->getParentOp(); 2023 while (module && !satisfiesLLVMModule(module)) 2024 module = module->getParentOp(); 2025 assert(module && "unexpected operation outside of a module"); 2026 return module; 2027 } 2028 2029 GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) { 2030 return dyn_cast_or_null<GlobalOp>( 2031 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); 2032 } 2033 2034 LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) { 2035 return dyn_cast_or_null<LLVMFuncOp>( 2036 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); 2037 } 2038 2039 LogicalResult 2040 AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2041 Operation *symbol = 2042 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()); 2043 2044 auto global = dyn_cast_or_null<GlobalOp>(symbol); 2045 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol); 2046 2047 if (!global && !function) 2048 return emitOpError( 2049 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 2050 2051 LLVMPointerType type = getType(); 2052 if (global && global.getAddrSpace() != type.getAddressSpace()) 2053 return emitOpError("pointer address space must match address space of the " 2054 "referenced global"); 2055 2056 return success(); 2057 } 2058 2059 // AddressOfOp constant-folds to the global symbol 
name. 2060 OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) { 2061 return getGlobalNameAttr(); 2062 } 2063 2064 //===----------------------------------------------------------------------===// 2065 // Verifier for LLVM::ComdatOp. 2066 //===----------------------------------------------------------------------===// 2067 2068 void ComdatOp::build(OpBuilder &builder, OperationState &result, 2069 StringRef symName) { 2070 result.addAttribute(getSymNameAttrName(result.name), 2071 builder.getStringAttr(symName)); 2072 Region *body = result.addRegion(); 2073 body->emplaceBlock(); 2074 } 2075 2076 LogicalResult ComdatOp::verifyRegions() { 2077 Region &body = getBody(); 2078 for (Operation &op : body.getOps()) 2079 if (!isa<ComdatSelectorOp>(op)) 2080 return op.emitError( 2081 "only comdat selector symbols can appear in a comdat region"); 2082 2083 return success(); 2084 } 2085 2086 //===----------------------------------------------------------------------===// 2087 // Builder, printer and verifier for LLVM::GlobalOp. 
2088 //===----------------------------------------------------------------------===// 2089 2090 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 2091 bool isConstant, Linkage linkage, StringRef name, 2092 Attribute value, uint64_t alignment, unsigned addrSpace, 2093 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat, 2094 ArrayRef<NamedAttribute> attrs, 2095 ArrayRef<Attribute> dbgExprs) { 2096 result.addAttribute(getSymNameAttrName(result.name), 2097 builder.getStringAttr(name)); 2098 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 2099 if (isConstant) 2100 result.addAttribute(getConstantAttrName(result.name), 2101 builder.getUnitAttr()); 2102 if (value) 2103 result.addAttribute(getValueAttrName(result.name), value); 2104 if (dsoLocal) 2105 result.addAttribute(getDsoLocalAttrName(result.name), 2106 builder.getUnitAttr()); 2107 if (threadLocal) 2108 result.addAttribute(getThreadLocal_AttrName(result.name), 2109 builder.getUnitAttr()); 2110 if (comdat) 2111 result.addAttribute(getComdatAttrName(result.name), comdat); 2112 2113 // Only add an alignment attribute if the "alignment" input 2114 // is different from 0. The value must also be a power of two, but 2115 // this is tested in GlobalOp::verify, not here. 
2116 if (alignment != 0) 2117 result.addAttribute(getAlignmentAttrName(result.name), 2118 builder.getI64IntegerAttr(alignment)); 2119 2120 result.addAttribute(getLinkageAttrName(result.name), 2121 LinkageAttr::get(builder.getContext(), linkage)); 2122 if (addrSpace != 0) 2123 result.addAttribute(getAddrSpaceAttrName(result.name), 2124 builder.getI32IntegerAttr(addrSpace)); 2125 result.attributes.append(attrs.begin(), attrs.end()); 2126 2127 if (!dbgExprs.empty()) 2128 result.addAttribute(getDbgExprsAttrName(result.name), 2129 ArrayAttr::get(builder.getContext(), dbgExprs)); 2130 2131 result.addRegion(); 2132 } 2133 2134 void GlobalOp::print(OpAsmPrinter &p) { 2135 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 2136 StringRef visibility = stringifyVisibility(getVisibility_()); 2137 if (!visibility.empty()) 2138 p << visibility << ' '; 2139 if (getThreadLocal_()) 2140 p << "thread_local "; 2141 if (auto unnamedAddr = getUnnamedAddr()) { 2142 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 2143 if (!str.empty()) 2144 p << str << ' '; 2145 } 2146 if (getConstant()) 2147 p << "constant "; 2148 p.printSymbolName(getSymName()); 2149 p << '('; 2150 if (auto value = getValueOrNull()) 2151 p.printAttribute(value); 2152 p << ')'; 2153 if (auto comdat = getComdat()) 2154 p << " comdat(" << *comdat << ')'; 2155 2156 // Note that the alignment attribute is printed using the 2157 // default syntax here, even though it is an inherent attribute 2158 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 2159 p.printOptionalAttrDict((*this)->getAttrs(), 2160 {SymbolTable::getSymbolAttrName(), 2161 getGlobalTypeAttrName(), getConstantAttrName(), 2162 getValueAttrName(), getLinkageAttrName(), 2163 getUnnamedAddrAttrName(), getThreadLocal_AttrName(), 2164 getVisibility_AttrName(), getComdatAttrName(), 2165 getUnnamedAddrAttrName()}); 2166 2167 // Print the trailing type unless it's a string global. 
2168 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) 2169 return; 2170 p << " : " << getType(); 2171 2172 Region &initializer = getInitializerRegion(); 2173 if (!initializer.empty()) { 2174 p << ' '; 2175 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 2176 } 2177 } 2178 2179 static LogicalResult verifyComdat(Operation *op, 2180 std::optional<SymbolRefAttr> attr) { 2181 if (!attr) 2182 return success(); 2183 2184 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr); 2185 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector)) 2186 return op->emitError() << "expected comdat symbol"; 2187 2188 return success(); 2189 } 2190 2191 // operation ::= `llvm.mlir.global` linkage? visibility? 2192 // (`unnamed_addr` | `local_unnamed_addr`)? 2193 // `thread_local`? `constant`? `@` identifier 2194 // `(` attribute? `)` (`comdat(` symbol-ref-id `)`)? 2195 // attribute-list? (`:` type)? region? 2196 // 2197 // The type can be omitted for string attributes, in which case it will be 2198 // inferred from the value of the string as [strlen(value) x i8]. 2199 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 2200 MLIRContext *ctx = parser.getContext(); 2201 // Parse optional linkage, default to External. 2202 result.addAttribute(getLinkageAttrName(result.name), 2203 LLVM::LinkageAttr::get( 2204 ctx, parseOptionalLLVMKeyword<Linkage>( 2205 parser, result, LLVM::Linkage::External))); 2206 2207 // Parse optional visibility, default to Default. 2208 result.addAttribute(getVisibility_AttrName(result.name), 2209 parser.getBuilder().getI64IntegerAttr( 2210 parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>( 2211 parser, result, LLVM::Visibility::Default))); 2212 2213 // Parse optional UnnamedAddr, default to None. 
2214 result.addAttribute(getUnnamedAddrAttrName(result.name), 2215 parser.getBuilder().getI64IntegerAttr( 2216 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 2217 parser, result, LLVM::UnnamedAddr::None))); 2218 2219 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 2220 result.addAttribute(getThreadLocal_AttrName(result.name), 2221 parser.getBuilder().getUnitAttr()); 2222 2223 if (succeeded(parser.parseOptionalKeyword("constant"))) 2224 result.addAttribute(getConstantAttrName(result.name), 2225 parser.getBuilder().getUnitAttr()); 2226 2227 StringAttr name; 2228 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 2229 result.attributes) || 2230 parser.parseLParen()) 2231 return failure(); 2232 2233 Attribute value; 2234 if (parser.parseOptionalRParen()) { 2235 if (parser.parseAttribute(value, getValueAttrName(result.name), 2236 result.attributes) || 2237 parser.parseRParen()) 2238 return failure(); 2239 } 2240 2241 if (succeeded(parser.parseOptionalKeyword("comdat"))) { 2242 SymbolRefAttr comdat; 2243 if (parser.parseLParen() || parser.parseAttribute(comdat) || 2244 parser.parseRParen()) 2245 return failure(); 2246 2247 result.addAttribute(getComdatAttrName(result.name), comdat); 2248 } 2249 2250 SmallVector<Type, 1> types; 2251 if (parser.parseOptionalAttrDict(result.attributes) || 2252 parser.parseOptionalColonTypeList(types)) 2253 return failure(); 2254 2255 if (types.size() > 1) 2256 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 2257 2258 Region &initRegion = *result.addRegion(); 2259 if (types.empty()) { 2260 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) { 2261 MLIRContext *context = parser.getContext(); 2262 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 2263 strAttr.getValue().size()); 2264 types.push_back(arrayType); 2265 } else { 2266 return parser.emitError(parser.getNameLoc(), 2267 "type can only be omitted for string globals"); 2268 } 2269 } else { 2270 
OptionalParseResult parseResult = 2271 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 2272 /*argTypes=*/{}); 2273 if (parseResult.has_value() && failed(*parseResult)) 2274 return failure(); 2275 } 2276 2277 result.addAttribute(getGlobalTypeAttrName(result.name), 2278 TypeAttr::get(types[0])); 2279 return success(); 2280 } 2281 2282 static bool isZeroAttribute(Attribute value) { 2283 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value)) 2284 return intValue.getValue().isZero(); 2285 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value)) 2286 return fpValue.getValue().isZero(); 2287 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value)) 2288 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 2289 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value)) 2290 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 2291 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value)) 2292 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 2293 return false; 2294 } 2295 2296 LogicalResult GlobalOp::verify() { 2297 bool validType = isCompatibleOuterType(getType()) 2298 ? !llvm::isa<LLVMVoidType, LLVMTokenType, 2299 LLVMMetadataType, LLVMLabelType>(getType()) 2300 : llvm::isa<PointerElementTypeInterface>(getType()); 2301 if (!validType) 2302 return emitOpError( 2303 "expects type to be a valid element type for an LLVM global"); 2304 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 2305 return emitOpError("must appear at the module level"); 2306 2307 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) { 2308 auto type = llvm::dyn_cast<LLVMArrayType>(getType()); 2309 IntegerType elementType = 2310 type ? 
llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr; 2311 if (!elementType || elementType.getWidth() != 8 || 2312 type.getNumElements() != strAttr.getValue().size()) 2313 return emitOpError( 2314 "requires an i8 array type of the length equal to that of the string " 2315 "attribute"); 2316 } 2317 2318 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { 2319 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal)) 2320 return emitOpError() 2321 << "this target extension type cannot be used in a global"; 2322 2323 if (Attribute value = getValueOrNull()) 2324 return emitOpError() << "global with target extension type can only be " 2325 "initialized with zero-initializer"; 2326 } 2327 2328 if (getLinkage() == Linkage::Common) { 2329 if (Attribute value = getValueOrNull()) { 2330 if (!isZeroAttribute(value)) { 2331 return emitOpError() 2332 << "expected zero value for '" 2333 << stringifyLinkage(Linkage::Common) << "' linkage"; 2334 } 2335 } 2336 } 2337 2338 if (getLinkage() == Linkage::Appending) { 2339 if (!llvm::isa<LLVMArrayType>(getType())) { 2340 return emitOpError() << "expected array type for '" 2341 << stringifyLinkage(Linkage::Appending) 2342 << "' linkage"; 2343 } 2344 } 2345 2346 if (failed(verifyComdat(*this, getComdat()))) 2347 return failure(); 2348 2349 std::optional<uint64_t> alignAttr = getAlignment(); 2350 if (alignAttr.has_value()) { 2351 uint64_t value = alignAttr.value(); 2352 if (!llvm::isPowerOf2_64(value)) 2353 return emitError() << "alignment attribute is not a power of 2"; 2354 } 2355 2356 return success(); 2357 } 2358 2359 LogicalResult GlobalOp::verifyRegions() { 2360 if (Block *b = getInitializerBlock()) { 2361 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 2362 if (ret.operand_type_begin() == ret.operand_type_end()) 2363 return emitOpError("initializer region cannot return void"); 2364 if (*ret.operand_type_begin() != getType()) 2365 return emitOpError("initializer region type ") 2366 << 
*ret.operand_type_begin() << " does not match global type " 2367 << getType(); 2368 2369 for (Operation &op : *b) { 2370 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 2371 if (!iface || !iface.hasNoEffect()) 2372 return op.emitError() 2373 << "ops with side effects not allowed in global initializers"; 2374 } 2375 2376 if (getValueOrNull()) 2377 return emitOpError("cannot have both initializer value and region"); 2378 } 2379 2380 return success(); 2381 } 2382 2383 //===----------------------------------------------------------------------===// 2384 // LLVM::GlobalCtorsOp 2385 //===----------------------------------------------------------------------===// 2386 2387 LogicalResult 2388 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2389 for (Attribute ctor : getCtors()) { 2390 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this, 2391 symbolTable))) 2392 return failure(); 2393 } 2394 return success(); 2395 } 2396 2397 LogicalResult GlobalCtorsOp::verify() { 2398 if (getCtors().size() != getPriorities().size()) 2399 return emitError( 2400 "mismatch between the number of ctors and the number of priorities"); 2401 return success(); 2402 } 2403 2404 //===----------------------------------------------------------------------===// 2405 // LLVM::GlobalDtorsOp 2406 //===----------------------------------------------------------------------===// 2407 2408 LogicalResult 2409 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2410 for (Attribute dtor : getDtors()) { 2411 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this, 2412 symbolTable))) 2413 return failure(); 2414 } 2415 return success(); 2416 } 2417 2418 LogicalResult GlobalDtorsOp::verify() { 2419 if (getDtors().size() != getPriorities().size()) 2420 return emitError( 2421 "mismatch between the number of dtors and the number of priorities"); 2422 return success(); 2423 } 2424 2425 
//===----------------------------------------------------------------------===// 2426 // ShuffleVectorOp 2427 //===----------------------------------------------------------------------===// 2428 2429 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, 2430 Value v2, DenseI32ArrayAttr mask, 2431 ArrayRef<NamedAttribute> attrs) { 2432 auto containerType = v1.getType(); 2433 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2434 mask.size(), 2435 LLVM::isScalableVectorType(containerType)); 2436 build(builder, state, vType, v1, v2, mask); 2437 state.addAttributes(attrs); 2438 } 2439 2440 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, 2441 Value v2, ArrayRef<int32_t> mask) { 2442 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask)); 2443 } 2444 2445 /// Build the result type of a shuffle vector operation. 2446 static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, 2447 Type &resType, DenseI32ArrayAttr mask) { 2448 if (!LLVM::isCompatibleVectorType(v1Type)) 2449 return parser.emitError(parser.getCurrentLocation(), 2450 "expected an LLVM compatible vector type"); 2451 resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(), 2452 LLVM::isScalableVectorType(v1Type)); 2453 return success(); 2454 } 2455 2456 /// Nothing to do when the result type is inferred. 2457 static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, 2458 Type resType, DenseI32ArrayAttr mask) {} 2459 2460 LogicalResult ShuffleVectorOp::verify() { 2461 if (LLVM::isScalableVectorType(getV1().getType()) && 2462 llvm::any_of(getMask(), [](int32_t v) { return v != 0; })) 2463 return emitOpError("expected a splat operation for scalable vectors"); 2464 return success(); 2465 } 2466 2467 //===----------------------------------------------------------------------===// 2468 // Implementations for LLVM::LLVMFuncOp. 
//===----------------------------------------------------------------------===//

// Add the entry block to the function.
//
// Asserts that the function has no blocks yet, creates a block in the body
// region and gives it one argument per parameter of the function type. All
// arguments use the location of the function itself. The builder is wrapped
// in an InsertionGuard for the duration of the call. Returns the new block.
Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
  assert(empty() && "function already has an entry block");
  OpBuilder::InsertionGuard g(builder);
  Block *entry = builder.createBlock(&getBody());

  // FIXME: Allow passing in proper locations for the entry arguments.
  LLVMFunctionType type = getFunctionType();
  for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
    entry->addArgument(type.getParamType(i), getLoc());
  return entry;
}

// Populates `result` for an `llvm.func` operation.
//
// Always adds one (empty) region plus the symbol name, function type, linkage
// and calling convention attributes, followed by the caller-provided `attrs`.
// The dso_local, comdat and function_entry_count attributes are added only
// when the corresponding argument is set. If `argAttrs` is non-empty it must
// contain exactly one dictionary per function parameter (asserted); `type` is
// then required to be an LLVMFunctionType.
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, Type type, LLVM::Linkage linkage,
                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                       ArrayRef<NamedAttribute> attrs,
                       ArrayRef<DictionaryAttr> argAttrs,
                       std::optional<uint64_t> functionEntryCount) {
  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(builder.getContext(), cconv));
  result.attributes.append(attrs.begin(), attrs.end());
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  if (functionEntryCount)
    result.addAttribute(getFunctionEntryCountAttrName(result.name),
                        builder.getI64IntegerAttr(functionEntryCount.value()));
  // Without argument attributes there is nothing left to attach.
  if (argAttrs.empty())
    return;

  assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
         "expected as many argument attribute lists as arguments");
  function_interface_impl::addArgAndResultAttrs(
      builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

// Builds an LLVM function type from the given lists of input and output types.
// Returns a null type if any of the types provided are non-LLVM types, or if
// there is more than one output type. In each failure case an error is
// emitted through `parser` at `loc` before returning.
static Type
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
                      ArrayRef<Type> outputs,
                      function_interface_impl::VariadicFlag variadicFlag) {
  Builder &b = parser.getBuilder();
  if (outputs.size() > 1) {
    parser.emitError(loc, "failed to construct function type: expected zero or "
                          "one function result");
    return {};
  }

  // Convert inputs to LLVM types, exit early on error.
  SmallVector<Type, 4> llvmInputs;
  for (auto t : inputs) {
    if (!isCompatibleType(t)) {
      parser.emitError(loc, "failed to construct function type: expected LLVM "
                            "type for function arguments");
      return {};
    }
    llvmInputs.push_back(t);
  }

  // No output is denoted as "void" in LLVM type system.
  Type llvmOutput =
      outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
  if (!isCompatibleType(llvmOutput)) {
    // The offending type is streamed directly after the message text.
    parser.emitError(loc, "failed to construct function type: expected LLVM "
                          "type for function results")
        << llvmOutput;
    return {};
  }
  return LLVMFunctionType::get(llvmOutput, llvmInputs,
                               variadicFlag.isVariadic());
}

// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//                function-signature
//                (`vscale_range(` integer `,` integer `)`)?
//                (`comdat(` symbol-ref-id `)`)?
//                function-attributes?
//                function-body
//
// The linkage, visibility, unnamed_addr and calling convention attributes are
// always added to `result`, using External / Default / None / C respectively
// when the keyword is absent. The body region is always added but is left
// empty when no region is present in the input.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Passed as an output argument to parseFunctionSignature below.
  bool isVariadic;

  // Remember where the signature starts so that type construction errors
  // point at it.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Assemble the LLVM function type from the parsed argument/result types.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Parse the optional `vscale_range(min, max)` clause; both bounds are
  // stored as 32-bit integer attributes.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  function_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The region is optional: only a region that is present but fails to parse
  // is reported as an error.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}

// Print the LLVMFuncOp. Collects argument and result types and passes them to
// helper functions. Drops "void" result since it cannot be parsed back. Skips
// the external linkage since it is the default value.
2652 void LLVMFuncOp::print(OpAsmPrinter &p) { 2653 p << ' '; 2654 if (getLinkage() != LLVM::Linkage::External) 2655 p << stringifyLinkage(getLinkage()) << ' '; 2656 StringRef visibility = stringifyVisibility(getVisibility_()); 2657 if (!visibility.empty()) 2658 p << visibility << ' '; 2659 if (auto unnamedAddr = getUnnamedAddr()) { 2660 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 2661 if (!str.empty()) 2662 p << str << ' '; 2663 } 2664 if (getCConv() != LLVM::CConv::C) 2665 p << stringifyCConv(getCConv()) << ' '; 2666 2667 p.printSymbolName(getName()); 2668 2669 LLVMFunctionType fnType = getFunctionType(); 2670 SmallVector<Type, 8> argTypes; 2671 SmallVector<Type, 1> resTypes; 2672 argTypes.reserve(fnType.getNumParams()); 2673 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2674 argTypes.push_back(fnType.getParamType(i)); 2675 2676 Type returnType = fnType.getReturnType(); 2677 if (!llvm::isa<LLVMVoidType>(returnType)) 2678 resTypes.push_back(returnType); 2679 2680 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2681 isVarArg(), resTypes); 2682 2683 // Print vscale range if present 2684 if (std::optional<VScaleRangeAttr> vscale = getVscaleRange()) 2685 p << " vscale_range(" << vscale->getMinRange().getInt() << ", " 2686 << vscale->getMaxRange().getInt() << ')'; 2687 2688 // Print the optional comdat selector. 2689 if (auto comdat = getComdat()) 2690 p << " comdat(" << *comdat << ')'; 2691 2692 function_interface_impl::printFunctionAttributes( 2693 p, *this, 2694 {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), 2695 getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(), 2696 getComdatAttrName(), getUnnamedAddrAttrName(), 2697 getVscaleRangeAttrName()}); 2698 2699 // Print the body if this is not an external function. 
2700 Region &body = getBody(); 2701 if (!body.empty()) { 2702 p << ' '; 2703 p.printRegion(body, /*printEntryBlockArgs=*/false, 2704 /*printBlockTerminators=*/true); 2705 } 2706 } 2707 2708 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2709 // - functions don't have 'common' linkage 2710 // - external functions have 'external' or 'extern_weak' linkage; 2711 // - vararg is (currently) only supported for external functions; 2712 LogicalResult LLVMFuncOp::verify() { 2713 if (getLinkage() == LLVM::Linkage::Common) 2714 return emitOpError() << "functions cannot have '" 2715 << stringifyLinkage(LLVM::Linkage::Common) 2716 << "' linkage"; 2717 2718 if (failed(verifyComdat(*this, getComdat()))) 2719 return failure(); 2720 2721 if (isExternal()) { 2722 if (getLinkage() != LLVM::Linkage::External && 2723 getLinkage() != LLVM::Linkage::ExternWeak) 2724 return emitOpError() << "external functions must have '" 2725 << stringifyLinkage(LLVM::Linkage::External) 2726 << "' or '" 2727 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2728 << "' linkage"; 2729 return success(); 2730 } 2731 2732 // In LLVM IR, these attributes are composed by convention, not by design. 
2733 if (isNoInline() && isAlwaysInline()) 2734 return emitError("no_inline and always_inline attributes are incompatible"); 2735 2736 if (isOptimizeNone() && !isNoInline()) 2737 return emitOpError("with optimize_none must also be no_inline"); 2738 2739 Type landingpadResultTy; 2740 StringRef diagnosticMessage; 2741 bool isLandingpadTypeConsistent = 2742 !walk([&](Operation *op) { 2743 const auto checkType = [&](Type type, StringRef errorMessage) { 2744 if (!landingpadResultTy) { 2745 landingpadResultTy = type; 2746 return WalkResult::advance(); 2747 } 2748 if (landingpadResultTy != type) { 2749 diagnosticMessage = errorMessage; 2750 return WalkResult::interrupt(); 2751 } 2752 return WalkResult::advance(); 2753 }; 2754 return TypeSwitch<Operation *, WalkResult>(op) 2755 .Case<LandingpadOp>([&](auto landingpad) { 2756 constexpr StringLiteral errorMessage = 2757 "'llvm.landingpad' should have a consistent result type " 2758 "inside a function"; 2759 return checkType(landingpad.getType(), errorMessage); 2760 }) 2761 .Case<ResumeOp>([&](auto resume) { 2762 constexpr StringLiteral errorMessage = 2763 "'llvm.resume' should have a consistent input type inside a " 2764 "function"; 2765 return checkType(resume.getValue().getType(), errorMessage); 2766 }) 2767 .Default([](auto) { return WalkResult::skip(); }); 2768 }).wasInterrupted(); 2769 if (!isLandingpadTypeConsistent) { 2770 assert(!diagnosticMessage.empty() && 2771 "Expecting a non-empty diagnostic message"); 2772 return emitError(diagnosticMessage); 2773 } 2774 2775 return success(); 2776 } 2777 2778 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2779 /// - entry block arguments are of LLVM types. 
2780 LogicalResult LLVMFuncOp::verifyRegions() { 2781 if (isExternal()) 2782 return success(); 2783 2784 unsigned numArguments = getFunctionType().getNumParams(); 2785 Block &entryBlock = front(); 2786 for (unsigned i = 0; i < numArguments; ++i) { 2787 Type argType = entryBlock.getArgument(i).getType(); 2788 if (!isCompatibleType(argType)) 2789 return emitOpError("entry block argument #") 2790 << i << " is not of LLVM type"; 2791 } 2792 2793 return success(); 2794 } 2795 2796 Region *LLVMFuncOp::getCallableRegion() { 2797 if (isExternal()) 2798 return nullptr; 2799 return &getBody(); 2800 } 2801 2802 //===----------------------------------------------------------------------===// 2803 // UndefOp. 2804 //===----------------------------------------------------------------------===// 2805 2806 /// Fold an undef operation to a dedicated undef attribute. 2807 OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) { 2808 return LLVM::UndefAttr::get(getContext()); 2809 } 2810 2811 //===----------------------------------------------------------------------===// 2812 // PoisonOp. 2813 //===----------------------------------------------------------------------===// 2814 2815 /// Fold a poison operation to a dedicated poison attribute. 2816 OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) { 2817 return LLVM::PoisonAttr::get(getContext()); 2818 } 2819 2820 //===----------------------------------------------------------------------===// 2821 // ZeroOp. 2822 //===----------------------------------------------------------------------===// 2823 2824 LogicalResult LLVM::ZeroOp::verify() { 2825 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) 2826 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit)) 2827 return emitOpError() 2828 << "target extension type does not support zero-initializer"; 2829 2830 return success(); 2831 } 2832 2833 /// Fold a zero operation to a builtin zero attribute when possible and fall 2834 /// back to a dedicated zero attribute. 
2835 OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) { 2836 OpFoldResult result = Builder(getContext()).getZeroAttr(getType()); 2837 if (result) 2838 return result; 2839 return LLVM::ZeroAttr::get(getContext()); 2840 } 2841 2842 //===----------------------------------------------------------------------===// 2843 // ConstantOp. 2844 //===----------------------------------------------------------------------===// 2845 2846 /// Compute the total number of elements in the given type, also taking into 2847 /// account nested types. Supported types are `VectorType`, `LLVMArrayType` and 2848 /// `LLVMFixedVectorType`. Everything else is treated as a scalar. 2849 static int64_t getNumElements(Type t) { 2850 if (auto vecType = dyn_cast<VectorType>(t)) 2851 return vecType.getNumElements() * getNumElements(vecType.getElementType()); 2852 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) 2853 return arrayType.getNumElements() * 2854 getNumElements(arrayType.getElementType()); 2855 if (auto vecType = dyn_cast<LLVMFixedVectorType>(t)) 2856 return vecType.getNumElements() * getNumElements(vecType.getElementType()); 2857 assert(!isa<LLVM::LLVMScalableVectorType>(t) && 2858 "number of elements of a scalable vector type is unknown"); 2859 return 1; 2860 } 2861 2862 /// Check if the given type is a scalable vector type or a vector/array type 2863 /// that contains a nested scalable vector type. 
2864 static bool hasScalableVectorType(Type t) { 2865 if (isa<LLVM::LLVMScalableVectorType>(t)) 2866 return true; 2867 if (auto vecType = dyn_cast<VectorType>(t)) { 2868 if (vecType.isScalable()) 2869 return true; 2870 return hasScalableVectorType(vecType.getElementType()); 2871 } 2872 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) 2873 return hasScalableVectorType(arrayType.getElementType()); 2874 if (auto vecType = dyn_cast<LLVMFixedVectorType>(t)) 2875 return hasScalableVectorType(vecType.getElementType()); 2876 return false; 2877 } 2878 2879 LogicalResult LLVM::ConstantOp::verify() { 2880 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) { 2881 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType()); 2882 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2883 !arrayType.getElementType().isInteger(8)) { 2884 return emitOpError() << "expected array type of " 2885 << sAttr.getValue().size() 2886 << " i8 elements for the string constant"; 2887 } 2888 return success(); 2889 } 2890 if (auto structType = dyn_cast<LLVMStructType>(getType())) { 2891 auto arrayAttr = dyn_cast<ArrayAttr>(getValue()); 2892 if (!arrayAttr) { 2893 return emitOpError() << "expected array attribute for a struct constant"; 2894 } 2895 2896 ArrayRef<Type> elementTypes = structType.getBody(); 2897 if (arrayAttr.size() != elementTypes.size()) { 2898 return emitOpError() << "expected array attribute of size " 2899 << elementTypes.size(); 2900 } 2901 for (auto elementTy : elementTypes) { 2902 if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) { 2903 return emitOpError() << "expected struct element types to be floating " 2904 "point type or integer type"; 2905 } 2906 } 2907 2908 for (size_t i = 0; i < elementTypes.size(); ++i) { 2909 Attribute element = arrayAttr[i]; 2910 if (!isa<IntegerAttr, FloatAttr>(element)) { 2911 return emitOpError() 2912 << "expected struct element attribute types to be floating " 2913 "point type or integer type"; 
2914 } 2915 auto elementType = cast<TypedAttr>(element).getType(); 2916 if (elementType != elementTypes[i]) { 2917 return emitOpError() 2918 << "struct element at index " << i << " is of wrong type"; 2919 } 2920 } 2921 2922 return success(); 2923 } 2924 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { 2925 return emitOpError() << "does not support target extension type."; 2926 } 2927 2928 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr. 2929 if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) { 2930 if (!llvm::isa<IntegerType>(getType())) 2931 return emitOpError() << "expected integer type"; 2932 } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) { 2933 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); 2934 unsigned floatWidth = APFloat::getSizeInBits(sem); 2935 if (auto floatTy = dyn_cast<FloatType>(getType())) { 2936 if (floatTy.getWidth() != floatWidth) { 2937 return emitOpError() << "expected float type of width " << floatWidth; 2938 } 2939 } 2940 // See the comment for getLLVMConstant for more details about why 8-bit 2941 // floats can be represented by integers. 2942 if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) { 2943 return emitOpError() << "expected integer type of width " << floatWidth; 2944 } 2945 } else if (isa<ElementsAttr, ArrayAttr>(getValue())) { 2946 if (hasScalableVectorType(getType())) { 2947 // The exact number of elements of a scalable vector is unknown, so we 2948 // allow only splat attributes. 2949 auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue()); 2950 if (!splatElementsAttr) 2951 return emitOpError() 2952 << "scalable vector type requires a splat attribute"; 2953 return success(); 2954 } 2955 if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>( 2956 getType())) 2957 return emitOpError() << "expected vector or array type"; 2958 // The number of elements of the attribute and the type must match. 
2959 int64_t attrNumElements; 2960 if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) 2961 attrNumElements = elementsAttr.getNumElements(); 2962 else 2963 attrNumElements = cast<ArrayAttr>(getValue()).size(); 2964 if (getNumElements(getType()) != attrNumElements) 2965 return emitOpError() 2966 << "type and attribute have a different number of elements: " 2967 << getNumElements(getType()) << " vs. " << attrNumElements; 2968 } else { 2969 return emitOpError() 2970 << "only supports integer, float, string or elements attributes"; 2971 } 2972 2973 return success(); 2974 } 2975 2976 bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { 2977 // The value's type must be the same as the provided type. 2978 auto typedAttr = dyn_cast<TypedAttr>(value); 2979 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type)) 2980 return false; 2981 // The value's type must be an LLVM compatible type. 2982 if (!isCompatibleType(type)) 2983 return false; 2984 // TODO: Add support for additional attributes kinds once needed. 2985 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value); 2986 } 2987 2988 ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, 2989 Type type, Location loc) { 2990 if (isBuildableWith(value, type)) 2991 return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value)); 2992 return nullptr; 2993 } 2994 2995 // Constant op constant-folds to its value. 
2996 OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } 2997 2998 //===----------------------------------------------------------------------===// 2999 // AtomicRMWOp 3000 //===----------------------------------------------------------------------===// 3001 3002 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, 3003 AtomicBinOp binOp, Value ptr, Value val, 3004 AtomicOrdering ordering, StringRef syncscope, 3005 unsigned alignment, bool isVolatile) { 3006 build(builder, state, val.getType(), binOp, ptr, val, ordering, 3007 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, 3008 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, 3009 /*access_groups=*/nullptr, 3010 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 3011 } 3012 3013 LogicalResult AtomicRMWOp::verify() { 3014 auto valType = getVal().getType(); 3015 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub || 3016 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) { 3017 if (isCompatibleVectorType(valType)) { 3018 if (isScalableVectorType(valType)) 3019 return emitOpError("expected LLVM IR fixed vector type"); 3020 Type elemType = getVectorElementType(valType); 3021 if (!isCompatibleFloatingPointType(elemType)) 3022 return emitOpError( 3023 "expected LLVM IR floating point type for vector element"); 3024 } else if (!isCompatibleFloatingPointType(valType)) { 3025 return emitOpError("expected LLVM IR floating point type"); 3026 } 3027 } else if (getBinOp() == AtomicBinOp::xchg) { 3028 DataLayout dataLayout = DataLayout::closest(*this); 3029 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout)) 3030 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 3031 } else { 3032 auto intType = llvm::dyn_cast<IntegerType>(valType); 3033 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 3034 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 3035 intBitWidth != 64) 3036 return emitOpError("expected LLVM IR integer type"); 3037 } 3038 3039 if (static_cast<unsigned>(getOrdering()) < 3040 static_cast<unsigned>(AtomicOrdering::monotonic)) 3041 return emitOpError() << "expected at least '" 3042 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 3043 << "' ordering"; 3044 3045 return success(); 3046 } 3047 3048 //===----------------------------------------------------------------------===// 3049 // AtomicCmpXchgOp 3050 //===----------------------------------------------------------------------===// 3051 3052 /// Returns an LLVM struct type that contains a value type and a boolean type. 3053 static LLVMStructType getValAndBoolStructType(Type valType) { 3054 auto boolType = IntegerType::get(valType.getContext(), 1); 3055 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType}); 3056 } 3057 3058 void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state, 3059 Value ptr, Value cmp, Value val, 3060 AtomicOrdering successOrdering, 3061 AtomicOrdering failureOrdering, StringRef syncscope, 3062 unsigned alignment, bool isWeak, bool isVolatile) { 3063 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val, 3064 successOrdering, failureOrdering, 3065 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, 3066 alignment ? 
builder.getI64IntegerAttr(alignment) : nullptr, isWeak, 3067 isVolatile, /*access_groups=*/nullptr, 3068 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); 3069 } 3070 3071 LogicalResult AtomicCmpXchgOp::verify() { 3072 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()); 3073 if (!ptrType) 3074 return emitOpError("expected LLVM IR pointer type for operand #0"); 3075 auto valType = getVal().getType(); 3076 DataLayout dataLayout = DataLayout::closest(*this); 3077 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout)) 3078 return emitOpError("unexpected LLVM IR type"); 3079 if (getSuccessOrdering() < AtomicOrdering::monotonic || 3080 getFailureOrdering() < AtomicOrdering::monotonic) 3081 return emitOpError("ordering must be at least 'monotonic'"); 3082 if (getFailureOrdering() == AtomicOrdering::release || 3083 getFailureOrdering() == AtomicOrdering::acq_rel) 3084 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 3085 return success(); 3086 } 3087 3088 //===----------------------------------------------------------------------===// 3089 // FenceOp 3090 //===----------------------------------------------------------------------===// 3091 3092 void FenceOp::build(OpBuilder &builder, OperationState &state, 3093 AtomicOrdering ordering, StringRef syncscope) { 3094 build(builder, state, ordering, 3095 syncscope.empty() ? 
nullptr : builder.getStringAttr(syncscope)); 3096 } 3097 3098 LogicalResult FenceOp::verify() { 3099 if (getOrdering() == AtomicOrdering::not_atomic || 3100 getOrdering() == AtomicOrdering::unordered || 3101 getOrdering() == AtomicOrdering::monotonic) 3102 return emitOpError("can be given only acquire, release, acq_rel, " 3103 "and seq_cst orderings"); 3104 return success(); 3105 } 3106 3107 //===----------------------------------------------------------------------===// 3108 // Verifier for extension ops 3109 //===----------------------------------------------------------------------===// 3110 3111 /// Verifies that the given extension operation operates on consistent scalars 3112 /// or vectors, and that the target width is larger than the input width. 3113 template <class ExtOp> 3114 static LogicalResult verifyExtOp(ExtOp op) { 3115 IntegerType inputType, outputType; 3116 if (isCompatibleVectorType(op.getArg().getType())) { 3117 if (!isCompatibleVectorType(op.getResult().getType())) 3118 return op.emitError( 3119 "input type is a vector but output type is an integer"); 3120 if (getVectorNumElements(op.getArg().getType()) != 3121 getVectorNumElements(op.getResult().getType())) 3122 return op.emitError("input and output vectors are of incompatible shape"); 3123 // Because this is a CastOp, the element of vectors is guaranteed to be an 3124 // integer. 3125 inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType())); 3126 outputType = 3127 cast<IntegerType>(getVectorElementType(op.getResult().getType())); 3128 } else { 3129 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be 3130 // an integer. 
3131 inputType = cast<IntegerType>(op.getArg().getType()); 3132 outputType = dyn_cast<IntegerType>(op.getResult().getType()); 3133 if (!outputType) 3134 return op.emitError( 3135 "input type is an integer but output type is a vector"); 3136 } 3137 3138 if (outputType.getWidth() <= inputType.getWidth()) 3139 return op.emitError("integer width of the output type is smaller or " 3140 "equal to the integer width of the input type"); 3141 return success(); 3142 } 3143 3144 //===----------------------------------------------------------------------===// 3145 // ZExtOp 3146 //===----------------------------------------------------------------------===// 3147 3148 LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); } 3149 3150 OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) { 3151 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); 3152 if (!arg) 3153 return {}; 3154 3155 size_t targetSize = cast<IntegerType>(getType()).getWidth(); 3156 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize)); 3157 } 3158 3159 //===----------------------------------------------------------------------===// 3160 // SExtOp 3161 //===----------------------------------------------------------------------===// 3162 3163 LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); } 3164 3165 //===----------------------------------------------------------------------===// 3166 // Folder and verifier for LLVM::BitcastOp 3167 //===----------------------------------------------------------------------===// 3168 3169 /// Folds a cast op that can be chained. 
3170 template <typename T> 3171 static OpFoldResult foldChainableCast(T castOp, 3172 typename T::FoldAdaptor adaptor) { 3173 // cast(x : T0, T0) -> x 3174 if (castOp.getArg().getType() == castOp.getType()) 3175 return castOp.getArg(); 3176 if (auto prev = castOp.getArg().template getDefiningOp<T>()) { 3177 // cast(cast(x : T0, T1), T0) -> x 3178 if (prev.getArg().getType() == castOp.getType()) 3179 return prev.getArg(); 3180 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2) 3181 castOp.getArgMutable().set(prev.getArg()); 3182 return Value{castOp}; 3183 } 3184 return {}; 3185 } 3186 3187 OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) { 3188 return foldChainableCast(*this, adaptor); 3189 } 3190 3191 LogicalResult LLVM::BitcastOp::verify() { 3192 auto resultType = llvm::dyn_cast<LLVMPointerType>( 3193 extractVectorElementType(getResult().getType())); 3194 auto sourceType = llvm::dyn_cast<LLVMPointerType>( 3195 extractVectorElementType(getArg().getType())); 3196 3197 // If one of the types is a pointer (or vector of pointers), then 3198 // both source and result type have to be pointers. 3199 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType)) 3200 return emitOpError("can only cast pointers from and to pointers"); 3201 3202 if (!resultType) 3203 return success(); 3204 3205 auto isVector = 3206 llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>; 3207 3208 // Due to bitcast requiring both operands to be of the same size, it is not 3209 // possible for only one of the two to be a pointer of vectors. 3210 if (isVector(getResult().getType()) && !isVector(getArg().getType())) 3211 return emitOpError("cannot cast pointer to vector of pointers"); 3212 3213 if (!isVector(getResult().getType()) && isVector(getArg().getType())) 3214 return emitOpError("cannot cast vector of pointers to pointer"); 3215 3216 // Bitcast cannot cast between pointers of different address spaces. 
3217 // 'llvm.addrspacecast' must be used for this purpose instead. 3218 if (resultType.getAddressSpace() != sourceType.getAddressSpace()) 3219 return emitOpError("cannot cast pointers of different address spaces, " 3220 "use 'llvm.addrspacecast' instead"); 3221 3222 return success(); 3223 } 3224 3225 //===----------------------------------------------------------------------===// 3226 // Folder for LLVM::AddrSpaceCastOp 3227 //===----------------------------------------------------------------------===// 3228 3229 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) { 3230 return foldChainableCast(*this, adaptor); 3231 } 3232 3233 Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); } 3234 3235 //===----------------------------------------------------------------------===// 3236 // Folder for LLVM::GEPOp 3237 //===----------------------------------------------------------------------===// 3238 3239 OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) { 3240 GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(), 3241 adaptor.getDynamicIndices()); 3242 3243 // gep %x:T, 0 -> %x 3244 if (getBase().getType() == getType() && indices.size() == 1) 3245 if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0])) 3246 if (integer.getValue().isZero()) 3247 return getBase(); 3248 3249 // Canonicalize any dynamic indices of constant value to constant indices. 3250 bool changed = false; 3251 SmallVector<GEPArg> gepArgs; 3252 for (auto iter : llvm::enumerate(indices)) { 3253 auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value()); 3254 // Constant indices can only be int32_t, so if integer does not fit we 3255 // are forced to keep it dynamic, despite being a constant. 
3256 if (!indices.isDynamicIndex(iter.index()) || !integer || 3257 !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) { 3258 3259 PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()]; 3260 if (Value val = llvm::dyn_cast_if_present<Value>(existing)) 3261 gepArgs.emplace_back(val); 3262 else 3263 gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt()); 3264 3265 continue; 3266 } 3267 3268 changed = true; 3269 gepArgs.emplace_back(integer.getInt()); 3270 } 3271 if (changed) { 3272 SmallVector<int32_t> rawConstantIndices; 3273 SmallVector<Value> dynamicIndices; 3274 destructureIndices(getElemType(), gepArgs, rawConstantIndices, 3275 dynamicIndices); 3276 3277 getDynamicIndicesMutable().assign(dynamicIndices); 3278 setRawConstantIndices(rawConstantIndices); 3279 return Value{*this}; 3280 } 3281 3282 return {}; 3283 } 3284 3285 Value LLVM::GEPOp::getViewSource() { return getBase(); } 3286 3287 //===----------------------------------------------------------------------===// 3288 // ShlOp 3289 //===----------------------------------------------------------------------===// 3290 3291 OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) { 3292 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); 3293 if (!rhs) 3294 return {}; 3295 3296 if (rhs.getValue().getZExtValue() >= 3297 getLhs().getType().getIntOrFloatBitWidth()) 3298 return {}; // TODO: Fold into poison. 
3299 3300 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); 3301 if (!lhs) 3302 return {}; 3303 3304 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue())); 3305 } 3306 3307 //===----------------------------------------------------------------------===// 3308 // OrOp 3309 //===----------------------------------------------------------------------===// 3310 3311 OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) { 3312 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); 3313 if (!lhs) 3314 return {}; 3315 3316 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); 3317 if (!rhs) 3318 return {}; 3319 3320 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue()); 3321 } 3322 3323 //===----------------------------------------------------------------------===// 3324 // CallIntrinsicOp 3325 //===----------------------------------------------------------------------===// 3326 3327 LogicalResult CallIntrinsicOp::verify() { 3328 if (!getIntrin().starts_with("llvm.")) 3329 return emitOpError() << "intrinsic name must start with 'llvm.'"; 3330 if (failed(verifyOperandBundles(*this))) 3331 return failure(); 3332 return success(); 3333 } 3334 3335 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, 3336 mlir::StringAttr intrin, mlir::ValueRange args) { 3337 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, 3338 FastmathFlagsAttr{}, 3339 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); 3340 } 3341 3342 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, 3343 mlir::StringAttr intrin, mlir::ValueRange args, 3344 mlir::LLVM::FastmathFlagsAttr fastMathFlags) { 3345 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, 3346 fastMathFlags, 3347 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); 3348 } 3349 3350 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, 3351 mlir::Type resultType, mlir::StringAttr intrin, 3352 mlir::ValueRange args) { 3353 
build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{}, 3354 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); 3355 } 3356 3357 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, 3358 mlir::TypeRange resultTypes, 3359 mlir::StringAttr intrin, mlir::ValueRange args, 3360 mlir::LLVM::FastmathFlagsAttr fastMathFlags) { 3361 build(builder, state, resultTypes, intrin, args, fastMathFlags, 3362 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); 3363 } 3364 3365 //===----------------------------------------------------------------------===// 3366 // OpAsmDialectInterface 3367 //===----------------------------------------------------------------------===// 3368 3369 namespace { 3370 struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface { 3371 using OpAsmDialectInterface::OpAsmDialectInterface; 3372 3373 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 3374 return TypeSwitch<Attribute, AliasResult>(attr) 3375 .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr, 3376 DIBasicTypeAttr, DICommonBlockAttr, DICompileUnitAttr, 3377 DICompositeTypeAttr, DIDerivedTypeAttr, DIFileAttr, 3378 DIGlobalVariableAttr, DIGlobalVariableExpressionAttr, 3379 DIImportedEntityAttr, DILabelAttr, DILexicalBlockAttr, 3380 DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr, 3381 DINamespaceAttr, DINullTypeAttr, DIStringTypeAttr, 3382 DISubprogramAttr, DISubroutineTypeAttr, LoopAnnotationAttr, 3383 LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr, 3384 LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, 3385 LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, 3386 TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) { 3387 os << decltype(attr)::getMnemonic(); 3388 return AliasResult::OverridableAlias; 3389 }) 3390 .Default([](Attribute) { return AliasResult::NoAlias; }); 3391 } 3392 }; 3393 } // namespace 3394 3395 //===----------------------------------------------------------------------===// 
3396 // LinkerOptionsOp 3397 //===----------------------------------------------------------------------===// 3398 3399 LogicalResult LinkerOptionsOp::verify() { 3400 if (mlir::Operation *parentOp = (*this)->getParentOp(); 3401 parentOp && !satisfiesLLVMModule(parentOp)) 3402 return emitOpError("must appear at the module level"); 3403 return success(); 3404 } 3405 3406 //===----------------------------------------------------------------------===// 3407 // InlineAsmOp 3408 //===----------------------------------------------------------------------===// 3409 3410 void InlineAsmOp::getEffects( 3411 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 3412 &effects) { 3413 if (getHasSideEffects()) { 3414 effects.emplace_back(MemoryEffects::Write::get()); 3415 effects.emplace_back(MemoryEffects::Read::get()); 3416 } 3417 } 3418 3419 //===----------------------------------------------------------------------===// 3420 // AssumeOp (intrinsic) 3421 //===----------------------------------------------------------------------===// 3422 3423 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, 3424 mlir::Value cond) { 3425 return build(builder, state, cond, /*op_bundle_operands=*/{}, 3426 /*op_bundle_tags=*/ArrayAttr{}); 3427 } 3428 3429 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, 3430 Value cond, 3431 ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) { 3432 SmallVector<ValueRange> opBundleOperands; 3433 SmallVector<Attribute> opBundleTags; 3434 opBundleOperands.reserve(opBundles.size()); 3435 opBundleTags.reserve(opBundles.size()); 3436 3437 for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) { 3438 opBundleOperands.emplace_back(bundle.inputs()); 3439 opBundleTags.push_back( 3440 StringAttr::get(builder.getContext(), bundle.getTag())); 3441 } 3442 3443 auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags); 3444 return build(builder, state, cond, opBundleOperands, opBundleTagsAttr); 
3445 } 3446 3447 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, 3448 Value cond, llvm::StringRef tag, ValueRange args) { 3449 llvm::OperandBundleDefT<Value> opBundle( 3450 tag.str(), SmallVector<Value>(args.begin(), args.end())); 3451 return build(builder, state, cond, opBundle); 3452 } 3453 3454 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, 3455 Value cond, AssumeAlignTag, Value ptr, Value align) { 3456 return build(builder, state, cond, "align", ValueRange{ptr, align}); 3457 } 3458 3459 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, 3460 Value cond, AssumeSeparateStorageTag, Value ptr1, 3461 Value ptr2) { 3462 return build(builder, state, cond, "separate_storage", 3463 ValueRange{ptr1, ptr2}); 3464 } 3465 3466 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); } 3467 3468 //===----------------------------------------------------------------------===// 3469 // masked_gather (intrinsic) 3470 //===----------------------------------------------------------------------===// 3471 3472 LogicalResult LLVM::masked_gather::verify() { 3473 auto ptrsVectorType = getPtrs().getType(); 3474 Type expectedPtrsVectorType = 3475 LLVM::getVectorType(extractVectorElementType(ptrsVectorType), 3476 LLVM::getVectorNumElements(getRes().getType())); 3477 // Vector of pointers type should match result vector type, other than the 3478 // element type. 
3479 if (ptrsVectorType != expectedPtrsVectorType) 3480 return emitOpError("expected operand #1 type to be ") 3481 << expectedPtrsVectorType; 3482 return success(); 3483 } 3484 3485 //===----------------------------------------------------------------------===// 3486 // masked_scatter (intrinsic) 3487 //===----------------------------------------------------------------------===// 3488 3489 LogicalResult LLVM::masked_scatter::verify() { 3490 auto ptrsVectorType = getPtrs().getType(); 3491 Type expectedPtrsVectorType = 3492 LLVM::getVectorType(extractVectorElementType(ptrsVectorType), 3493 LLVM::getVectorNumElements(getValue().getType())); 3494 // Vector of pointers type should match value vector type, other than the 3495 // element type. 3496 if (ptrsVectorType != expectedPtrsVectorType) 3497 return emitOpError("expected operand #2 type to be ") 3498 << expectedPtrsVectorType; 3499 return success(); 3500 } 3501 3502 //===----------------------------------------------------------------------===// 3503 // LLVMDialect initialization, type parsing, and registration. 3504 //===----------------------------------------------------------------------===// 3505 3506 void LLVMDialect::initialize() { 3507 registerAttributes(); 3508 3509 // clang-format off 3510 addTypes<LLVMVoidType, 3511 LLVMPPCFP128Type, 3512 LLVMTokenType, 3513 LLVMLabelType, 3514 LLVMMetadataType>(); 3515 // clang-format on 3516 registerTypes(); 3517 3518 addOperations< 3519 #define GET_OP_LIST 3520 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 3521 , 3522 #define GET_OP_LIST 3523 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 3524 >(); 3525 3526 // Support unknown operations because not all LLVM operations are registered. 
3527 allowUnknownOperations(); 3528 // clang-format off 3529 addInterfaces<LLVMOpAsmDialectInterface>(); 3530 // clang-format on 3531 declarePromisedInterface<DialectInlinerInterface, LLVMDialect>(); 3532 } 3533 3534 #define GET_OP_CLASSES 3535 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 3536 3537 #define GET_OP_CLASSES 3538 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 3539 3540 LogicalResult LLVMDialect::verifyDataLayoutString( 3541 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 3542 llvm::Expected<llvm::DataLayout> maybeDataLayout = 3543 llvm::DataLayout::parse(descr); 3544 if (maybeDataLayout) 3545 return success(); 3546 3547 std::string message; 3548 llvm::raw_string_ostream messageStream(message); 3549 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 3550 reportError("invalid data layout descriptor: " + message); 3551 return failure(); 3552 } 3553 3554 /// Verify LLVM dialect attributes. 3555 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 3556 NamedAttribute attr) { 3557 // If the data layout attribute is present, it must use the LLVM data layout 3558 // syntax. Try parsing it and report errors in case of failure. Users of this 3559 // attribute may assume it is well-formed and can pass it to the (asserting) 3560 // llvm::DataLayout constructor. 
3561 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 3562 return success(); 3563 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue())) 3564 return verifyDataLayoutString( 3565 stringAttr.getValue(), 3566 [op](const Twine &message) { op->emitOpError() << message.str(); }); 3567 3568 return op->emitOpError() << "expected '" 3569 << LLVM::LLVMDialect::getDataLayoutAttrName() 3570 << "' to be a string attributes"; 3571 } 3572 3573 LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op, 3574 Type paramType, 3575 NamedAttribute paramAttr) { 3576 // LLVM attribute may be attached to a result of operation that has not been 3577 // converted to LLVM dialect yet, so the result may have a type with unknown 3578 // representation in LLVM dialect type space. In this case we cannot verify 3579 // whether the attribute may be 3580 bool verifyValueType = isCompatibleType(paramType); 3581 StringAttr name = paramAttr.getName(); 3582 3583 auto checkUnitAttrType = [&]() -> LogicalResult { 3584 if (!llvm::isa<UnitAttr>(paramAttr.getValue())) 3585 return op->emitError() << name << " should be a unit attribute"; 3586 return success(); 3587 }; 3588 auto checkTypeAttrType = [&]() -> LogicalResult { 3589 if (!llvm::isa<TypeAttr>(paramAttr.getValue())) 3590 return op->emitError() << name << " should be a type attribute"; 3591 return success(); 3592 }; 3593 auto checkIntegerAttrType = [&]() -> LogicalResult { 3594 if (!llvm::isa<IntegerAttr>(paramAttr.getValue())) 3595 return op->emitError() << name << " should be an integer attribute"; 3596 return success(); 3597 }; 3598 auto checkPointerType = [&]() -> LogicalResult { 3599 if (!llvm::isa<LLVMPointerType>(paramType)) 3600 return op->emitError() 3601 << name << " attribute attached to non-pointer LLVM type"; 3602 return success(); 3603 }; 3604 auto checkIntegerType = [&]() -> LogicalResult { 3605 if (!llvm::isa<IntegerType>(paramType)) 3606 return op->emitError() 3607 << name << " attribute attached 
to non-integer LLVM type"; 3608 return success(); 3609 }; 3610 auto checkPointerTypeMatches = [&]() -> LogicalResult { 3611 if (failed(checkPointerType())) 3612 return failure(); 3613 3614 return success(); 3615 }; 3616 3617 // Check a unit attribute that is attached to a pointer value. 3618 if (name == LLVMDialect::getNoAliasAttrName() || 3619 name == LLVMDialect::getReadonlyAttrName() || 3620 name == LLVMDialect::getReadnoneAttrName() || 3621 name == LLVMDialect::getWriteOnlyAttrName() || 3622 name == LLVMDialect::getNestAttrName() || 3623 name == LLVMDialect::getNoCaptureAttrName() || 3624 name == LLVMDialect::getNoFreeAttrName() || 3625 name == LLVMDialect::getNonNullAttrName()) { 3626 if (failed(checkUnitAttrType())) 3627 return failure(); 3628 if (verifyValueType && failed(checkPointerType())) 3629 return failure(); 3630 return success(); 3631 } 3632 3633 // Check a type attribute that is attached to a pointer value. 3634 if (name == LLVMDialect::getStructRetAttrName() || 3635 name == LLVMDialect::getByValAttrName() || 3636 name == LLVMDialect::getByRefAttrName() || 3637 name == LLVMDialect::getInAllocaAttrName() || 3638 name == LLVMDialect::getPreallocatedAttrName()) { 3639 if (failed(checkTypeAttrType())) 3640 return failure(); 3641 if (verifyValueType && failed(checkPointerTypeMatches())) 3642 return failure(); 3643 return success(); 3644 } 3645 3646 // Check a unit attribute that is attached to an integer value. 3647 if (name == LLVMDialect::getSExtAttrName() || 3648 name == LLVMDialect::getZExtAttrName()) { 3649 if (failed(checkUnitAttrType())) 3650 return failure(); 3651 if (verifyValueType && failed(checkIntegerType())) 3652 return failure(); 3653 return success(); 3654 } 3655 3656 // Check an integer attribute that is attached to a pointer value. 
3657 if (name == LLVMDialect::getAlignAttrName() || 3658 name == LLVMDialect::getDereferenceableAttrName() || 3659 name == LLVMDialect::getDereferenceableOrNullAttrName() || 3660 name == LLVMDialect::getStackAlignmentAttrName()) { 3661 if (failed(checkIntegerAttrType())) 3662 return failure(); 3663 if (verifyValueType && failed(checkPointerType())) 3664 return failure(); 3665 return success(); 3666 } 3667 3668 // Check a unit attribute that can be attached to arbitrary types. 3669 if (name == LLVMDialect::getNoUndefAttrName() || 3670 name == LLVMDialect::getInRegAttrName() || 3671 name == LLVMDialect::getReturnedAttrName()) 3672 return checkUnitAttrType(); 3673 3674 return success(); 3675 } 3676 3677 /// Verify LLVMIR function argument attributes. 3678 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 3679 unsigned regionIdx, 3680 unsigned argIdx, 3681 NamedAttribute argAttr) { 3682 auto funcOp = dyn_cast<FunctionOpInterface>(op); 3683 if (!funcOp) 3684 return success(); 3685 Type argType = funcOp.getArgumentTypes()[argIdx]; 3686 3687 return verifyParameterAttribute(op, argType, argAttr); 3688 } 3689 3690 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 3691 unsigned regionIdx, 3692 unsigned resIdx, 3693 NamedAttribute resAttr) { 3694 auto funcOp = dyn_cast<FunctionOpInterface>(op); 3695 if (!funcOp) 3696 return success(); 3697 Type resType = funcOp.getResultTypes()[resIdx]; 3698 3699 // Check to see if this function has a void return with a result attribute 3700 // to it. It isn't clear what semantics we would assign to that. 3701 if (llvm::isa<LLVMVoidType>(resType)) 3702 return op->emitError() << "cannot attach result attributes to functions " 3703 "with a void return"; 3704 3705 // Check to see if this attribute is allowed as a result attribute. Only 3706 // explicitly forbidden LLVM attributes will cause an error. 
3707 auto name = resAttr.getName(); 3708 if (name == LLVMDialect::getAllocAlignAttrName() || 3709 name == LLVMDialect::getAllocatedPointerAttrName() || 3710 name == LLVMDialect::getByValAttrName() || 3711 name == LLVMDialect::getByRefAttrName() || 3712 name == LLVMDialect::getInAllocaAttrName() || 3713 name == LLVMDialect::getNestAttrName() || 3714 name == LLVMDialect::getNoCaptureAttrName() || 3715 name == LLVMDialect::getNoFreeAttrName() || 3716 name == LLVMDialect::getPreallocatedAttrName() || 3717 name == LLVMDialect::getReadnoneAttrName() || 3718 name == LLVMDialect::getReadonlyAttrName() || 3719 name == LLVMDialect::getReturnedAttrName() || 3720 name == LLVMDialect::getStackAlignmentAttrName() || 3721 name == LLVMDialect::getStructRetAttrName() || 3722 name == LLVMDialect::getWriteOnlyAttrName()) 3723 return op->emitError() << name << " is not a valid result attribute"; 3724 return verifyParameterAttribute(op, resType, resAttr); 3725 } 3726 3727 Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, 3728 Type type, Location loc) { 3729 // If this was folded from an operation other than llvm.mlir.constant, it 3730 // should be materialized as such. Note that an llvm.mlir.zero may fold into 3731 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant. 3732 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value)) 3733 if (isa<LLVM::LLVMPointerType>(type)) 3734 return builder.create<LLVM::AddressOfOp>(loc, type, symbol); 3735 if (isa<LLVM::UndefAttr>(value)) 3736 return builder.create<LLVM::UndefOp>(loc, type); 3737 if (isa<LLVM::PoisonAttr>(value)) 3738 return builder.create<LLVM::PoisonOp>(loc, type); 3739 if (isa<LLVM::ZeroAttr>(value)) 3740 return builder.create<LLVM::ZeroOp>(loc, type); 3741 // Otherwise try materializing it as a regular llvm.mlir.constant op. 
3742 return LLVM::ConstantOp::materialize(builder, value, type, loc); 3743 } 3744 3745 //===----------------------------------------------------------------------===// 3746 // Utility functions. 3747 //===----------------------------------------------------------------------===// 3748 3749 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 3750 StringRef name, StringRef value, 3751 LLVM::Linkage linkage) { 3752 assert(builder.getInsertionBlock() && 3753 builder.getInsertionBlock()->getParentOp() && 3754 "expected builder to point to a block constrained in an op"); 3755 auto module = 3756 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 3757 assert(module && "builder points to an op outside of a module"); 3758 3759 // Create the global at the entry of the module. 3760 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 3761 MLIRContext *ctx = builder.getContext(); 3762 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 3763 auto global = moduleBuilder.create<LLVM::GlobalOp>( 3764 loc, type, /*isConstant=*/true, linkage, name, 3765 builder.getStringAttr(value), /*alignment=*/0); 3766 3767 LLVMPointerType ptrType = LLVMPointerType::get(ctx); 3768 // Get the pointer to the first character in the global string. 3769 Value globalPtr = 3770 builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr()); 3771 return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr, 3772 ArrayRef<GEPArg>{0, 0}); 3773 } 3774 3775 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 3776 return op->hasTrait<OpTrait::SymbolTable>() && 3777 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 3778 } 3779