//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"

#include "SPIRVParsingUtils.h"

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::spirv;

#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
/// ops.
47 static inline bool containsReturn(Region ®ion) { 48 return llvm::any_of(region, [](Block &block) { 49 Operation *terminator = block.getTerminator(); 50 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); 51 }); 52 } 53 54 namespace { 55 /// This class defines the interface for inlining within the SPIR-V dialect. 56 struct SPIRVInlinerInterface : public DialectInlinerInterface { 57 using DialectInlinerInterface::DialectInlinerInterface; 58 59 /// All call operations within SPIRV can be inlined. 60 bool isLegalToInline(Operation *call, Operation *callable, 61 bool wouldBeCloned) const final { 62 return true; 63 } 64 65 /// Returns true if the given region 'src' can be inlined into the region 66 /// 'dest' that is attached to an operation registered to the current dialect. 67 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 68 IRMapping &) const final { 69 // Return true here when inlining into spirv.func, spirv.mlir.selection, and 70 // spirv.mlir.loop operations. 71 auto *op = dest->getParentOp(); 72 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); 73 } 74 75 /// Returns true if the given operation 'op', that is registered to this 76 /// dialect, can be inlined into the region 'dest' that is attached to an 77 /// operation registered to the current dialect. 78 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 79 IRMapping &) const final { 80 // TODO: Enable inlining structured control flows with return. 81 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && 82 containsReturn(op->getRegion(0))) 83 return false; 84 // TODO: we need to filter OpKill here to avoid inlining it to 85 // a loop continue construct: 86 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 87 // However OpKill is fragment shader specific and we don't support it yet. 88 return true; 89 } 90 91 /// Handle the given inlined terminator by replacing it with a new operation 92 /// as necessary. 
93 void handleTerminator(Operation *op, Block *newDest) const final { 94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { 95 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); 96 op->erase(); 97 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { 98 OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, 99 retValOp->getOperands()); 100 op->erase(); 101 } 102 } 103 104 /// Handle the given inlined terminator by replacing it with a new operation 105 /// as necessary. 106 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 107 // Only spirv.ReturnValue needs to be handled here. 108 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); 109 if (!retValOp) 110 return; 111 112 // Replace the values directly with the return operands. 113 assert(valuesToRepl.size() == 1 && 114 "spirv.ReturnValue expected to only handle one result"); 115 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue()); 116 } 117 }; 118 } // namespace 119 120 //===----------------------------------------------------------------------===// 121 // SPIR-V Dialect 122 //===----------------------------------------------------------------------===// 123 124 void SPIRVDialect::initialize() { 125 registerAttributes(); 126 registerTypes(); 127 128 // Add SPIR-V ops. 129 addOperations< 130 #define GET_OP_LIST 131 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" 132 >(); 133 134 addInterfaces<SPIRVInlinerInterface>(); 135 136 // Allow unknown operations because SPIR-V is extensible. 
137 allowUnknownOperations(); 138 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>(); 139 } 140 141 std::string SPIRVDialect::getAttributeName(Decoration decoration) { 142 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); 143 } 144 145 //===----------------------------------------------------------------------===// 146 // Type Parsing 147 //===----------------------------------------------------------------------===// 148 149 // Forward declarations. 150 template <typename ValTy> 151 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 152 DialectAsmParser &parser); 153 template <> 154 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 155 DialectAsmParser &parser); 156 157 template <> 158 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 159 DialectAsmParser &parser); 160 161 static Type parseAndVerifyType(SPIRVDialect const &dialect, 162 DialectAsmParser &parser) { 163 Type type; 164 SMLoc typeLoc = parser.getCurrentLocation(); 165 if (parser.parseType(type)) 166 return Type(); 167 168 // Allow SPIR-V dialect types 169 if (&type.getDialect() == &dialect) 170 return type; 171 172 // Check other allowed types 173 if (auto t = llvm::dyn_cast<FloatType>(type)) { 174 if (type.isBF16()) { 175 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); 176 return Type(); 177 } 178 } else if (auto t = llvm::dyn_cast<IntegerType>(type)) { 179 if (!ScalarType::isValid(t)) { 180 parser.emitError(typeLoc, 181 "only 1/8/16/32/64-bit integer type allowed but found ") 182 << type; 183 return Type(); 184 } 185 } else if (auto t = llvm::dyn_cast<VectorType>(type)) { 186 if (t.getRank() != 1) { 187 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 188 return Type(); 189 } 190 if (t.getNumElements() > 4) { 191 parser.emitError( 192 typeLoc, "vector length has to be less than or equal to 4 but found ") 193 << t.getNumElements(); 194 return 
Type(); 195 } 196 } else { 197 parser.emitError(typeLoc, "cannot use ") 198 << type << " to compose SPIR-V types"; 199 return Type(); 200 } 201 202 return type; 203 } 204 205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, 206 DialectAsmParser &parser) { 207 Type type; 208 SMLoc typeLoc = parser.getCurrentLocation(); 209 if (parser.parseType(type)) 210 return Type(); 211 212 if (auto t = llvm::dyn_cast<VectorType>(type)) { 213 if (t.getRank() != 1) { 214 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 215 return Type(); 216 } 217 if (t.getNumElements() > 4 || t.getNumElements() < 2) { 218 parser.emitError(typeLoc, 219 "matrix columns size has to be less than or equal " 220 "to 4 and greater than or equal 2, but found ") 221 << t.getNumElements(); 222 return Type(); 223 } 224 225 if (!llvm::isa<FloatType>(t.getElementType())) { 226 parser.emitError(typeLoc, "matrix columns' elements must be of " 227 "Float type, got ") 228 << t.getElementType(); 229 return Type(); 230 } 231 } else { 232 parser.emitError(typeLoc, "matrix must be composed using vector " 233 "type, got ") 234 << type; 235 return Type(); 236 } 237 238 return type; 239 } 240 241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, 242 DialectAsmParser &parser) { 243 Type type; 244 SMLoc typeLoc = parser.getCurrentLocation(); 245 if (parser.parseType(type)) 246 return Type(); 247 248 if (!llvm::isa<ImageType>(type)) { 249 parser.emitError(typeLoc, 250 "sampled image must be composed using image type, got ") 251 << type; 252 return Type(); 253 } 254 255 return type; 256 } 257 258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure 259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if 260 /// missing. 
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, 262 DialectAsmParser &parser, 263 unsigned &stride) { 264 if (failed(parser.parseOptionalComma())) { 265 stride = 0; 266 return success(); 267 } 268 269 if (parser.parseKeyword("stride") || parser.parseEqual()) 270 return failure(); 271 272 SMLoc strideLoc = parser.getCurrentLocation(); 273 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); 274 if (!optStride) 275 return failure(); 276 277 if (!(stride = *optStride)) { 278 parser.emitError(strideLoc, "ArrayStride must be greater than zero"); 279 return failure(); 280 } 281 return success(); 282 } 283 284 // element-type ::= integer-type 285 // | floating-point-type 286 // | vector-type 287 // | spirv-type 288 // 289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type 290 // (`,` `stride` `=` integer-literal)? `>` 291 static Type parseArrayType(SPIRVDialect const &dialect, 292 DialectAsmParser &parser) { 293 if (parser.parseLess()) 294 return Type(); 295 296 SmallVector<int64_t, 1> countDims; 297 SMLoc countLoc = parser.getCurrentLocation(); 298 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 299 return Type(); 300 if (countDims.size() != 1) { 301 parser.emitError(countLoc, 302 "expected single integer for array element count"); 303 return Type(); 304 } 305 306 // According to the SPIR-V spec: 307 // "Length is the number of elements in the array. It must be at least 1." 
308 int64_t count = countDims[0]; 309 if (count == 0) { 310 parser.emitError(countLoc, "expected array length greater than 0"); 311 return Type(); 312 } 313 314 Type elementType = parseAndVerifyType(dialect, parser); 315 if (!elementType) 316 return Type(); 317 318 unsigned stride = 0; 319 if (failed(parseOptionalArrayStride(dialect, parser, stride))) 320 return Type(); 321 322 if (parser.parseGreater()) 323 return Type(); 324 return ArrayType::get(elementType, count, stride); 325 } 326 327 // cooperative-matrix-type ::= 328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,` 329 // scope `,` use `>` 330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, 331 DialectAsmParser &parser) { 332 if (parser.parseLess()) 333 return {}; 334 335 SmallVector<int64_t, 2> dims; 336 SMLoc countLoc = parser.getCurrentLocation(); 337 if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) 338 return {}; 339 340 if (dims.size() != 2) { 341 parser.emitError(countLoc, "expected row and column count"); 342 return {}; 343 } 344 345 auto elementTy = parseAndVerifyType(dialect, parser); 346 if (!elementTy) 347 return {}; 348 349 Scope scope; 350 if (parser.parseComma() || 351 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>")) 352 return {}; 353 354 CooperativeMatrixUseKHR use; 355 if (parser.parseComma() || 356 spirv::parseEnumKeywordAttr(use, parser, "use <id>")) 357 return {}; 358 359 if (parser.parseGreater()) 360 return {}; 361 362 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); 363 } 364 365 // TODO: Reorder methods to be utilities first and parse*Type 366 // methods in alphabetical order 367 // 368 // storage-class ::= `UniformConstant` 369 // | `Uniform` 370 // | `Workgroup` 371 // | <and other storage classes...> 372 // 373 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>` 374 static Type parsePointerType(SPIRVDialect const &dialect, 375 DialectAsmParser &parser) { 376 if (parser.parseLess()) 
377 return Type(); 378 379 auto pointeeType = parseAndVerifyType(dialect, parser); 380 if (!pointeeType) 381 return Type(); 382 383 StringRef storageClassSpec; 384 SMLoc storageClassLoc = parser.getCurrentLocation(); 385 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec)) 386 return Type(); 387 388 auto storageClass = symbolizeStorageClass(storageClassSpec); 389 if (!storageClass) { 390 parser.emitError(storageClassLoc, "unknown storage class: ") 391 << storageClassSpec; 392 return Type(); 393 } 394 if (parser.parseGreater()) 395 return Type(); 396 return PointerType::get(pointeeType, *storageClass); 397 } 398 399 // runtime-array-type ::= `!spirv.rtarray` `<` element-type 400 // (`,` `stride` `=` integer-literal)? `>` 401 static Type parseRuntimeArrayType(SPIRVDialect const &dialect, 402 DialectAsmParser &parser) { 403 if (parser.parseLess()) 404 return Type(); 405 406 Type elementType = parseAndVerifyType(dialect, parser); 407 if (!elementType) 408 return Type(); 409 410 unsigned stride = 0; 411 if (failed(parseOptionalArrayStride(dialect, parser, stride))) 412 return Type(); 413 414 if (parser.parseGreater()) 415 return Type(); 416 return RuntimeArrayType::get(elementType, stride); 417 } 418 419 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>` 420 static Type parseMatrixType(SPIRVDialect const &dialect, 421 DialectAsmParser &parser) { 422 if (parser.parseLess()) 423 return Type(); 424 425 SmallVector<int64_t, 1> countDims; 426 SMLoc countLoc = parser.getCurrentLocation(); 427 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 428 return Type(); 429 if (countDims.size() != 1) { 430 parser.emitError(countLoc, "expected single unsigned " 431 "integer for number of columns"); 432 return Type(); 433 } 434 435 int64_t columnCount = countDims[0]; 436 // According to the specification, Matrices can have 2, 3, or 4 columns 437 if (columnCount < 2 || columnCount > 4) { 438 parser.emitError(countLoc, "matrix is 
expected to have 2, 3, or 4 " 439 "columns"); 440 return Type(); 441 } 442 443 Type columnType = parseAndVerifyMatrixType(dialect, parser); 444 if (!columnType) 445 return Type(); 446 447 if (parser.parseGreater()) 448 return Type(); 449 450 return MatrixType::get(columnType, columnCount); 451 } 452 453 // Specialize this function to parse each of the parameters that define an 454 // ImageType. By default it assumes this is an enum type. 455 template <typename ValTy> 456 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 457 DialectAsmParser &parser) { 458 StringRef enumSpec; 459 SMLoc enumLoc = parser.getCurrentLocation(); 460 if (parser.parseKeyword(&enumSpec)) { 461 return std::nullopt; 462 } 463 464 auto val = spirv::symbolizeEnum<ValTy>(enumSpec); 465 if (!val) 466 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; 467 return val; 468 } 469 470 template <> 471 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 472 DialectAsmParser &parser) { 473 // TODO: Further verify that the element type can be sampled 474 auto ty = parseAndVerifyType(dialect, parser); 475 if (!ty) 476 return std::nullopt; 477 return ty; 478 } 479 480 template <typename IntTy> 481 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, 482 DialectAsmParser &parser) { 483 IntTy offsetVal = std::numeric_limits<IntTy>::max(); 484 if (parser.parseInteger(offsetVal)) 485 return std::nullopt; 486 return offsetVal; 487 } 488 489 template <> 490 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 491 DialectAsmParser &parser) { 492 return parseAndVerifyInteger<unsigned>(dialect, parser); 493 } 494 495 namespace { 496 // Functor object to parse a comma separated list of specs. The function 497 // parseAndVerify does the actual parsing and verification of individual 498 // elements. 
This is a functor since parsing the last element of the list 499 // (termination condition) needs partial specialization. 500 template <typename ParseType, typename... Args> 501 struct ParseCommaSeparatedList { 502 std::optional<std::tuple<ParseType, Args...>> 503 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { 504 auto parseVal = parseAndVerify<ParseType>(dialect, parser); 505 if (!parseVal) 506 return std::nullopt; 507 508 auto numArgs = std::tuple_size<std::tuple<Args...>>::value; 509 if (numArgs != 0 && failed(parser.parseComma())) 510 return std::nullopt; 511 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); 512 if (!remainingValues) 513 return std::nullopt; 514 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()), 515 remainingValues.value()); 516 } 517 }; 518 519 // Partial specialization of the function to parse a comma separated list of 520 // specs to parse the last element of the list. 521 template <typename ParseType> 522 struct ParseCommaSeparatedList<ParseType> { 523 std::optional<std::tuple<ParseType>> 524 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { 525 if (auto value = parseAndVerify<ParseType>(dialect, parser)) 526 return std::tuple<ParseType>(*value); 527 return std::nullopt; 528 } 529 }; 530 } // namespace 531 532 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> 533 // 534 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` 535 // 536 // arrayed-info ::= `NonArrayed` | `Arrayed` 537 // 538 // sampling-info ::= `SingleSampled` | `MultiSampled` 539 // 540 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` 541 // 542 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> 543 // 544 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,` 545 // arrayed-info `,` sampling-info `,` 546 // sampler-use-info `,` format `>` 547 static Type parseImageType(SPIRVDialect const 
&dialect, 548 DialectAsmParser &parser) { 549 if (parser.parseLess()) 550 return Type(); 551 552 auto value = 553 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 554 ImageSamplingInfo, ImageSamplerUseInfo, 555 ImageFormat>{}(dialect, parser); 556 if (!value) 557 return Type(); 558 559 if (parser.parseGreater()) 560 return Type(); 561 return ImageType::get(*value); 562 } 563 564 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>` 565 static Type parseSampledImageType(SPIRVDialect const &dialect, 566 DialectAsmParser &parser) { 567 if (parser.parseLess()) 568 return Type(); 569 570 Type parsedType = parseAndVerifySampledImageType(dialect, parser); 571 if (!parsedType) 572 return Type(); 573 574 if (parser.parseGreater()) 575 return Type(); 576 return SampledImageType::get(parsedType); 577 } 578 579 // Parse decorations associated with a member. 580 static ParseResult parseStructMemberDecorations( 581 SPIRVDialect const &dialect, DialectAsmParser &parser, 582 ArrayRef<Type> memberTypes, 583 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, 584 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { 585 586 // Check if the first element is offset. 587 SMLoc offsetLoc = parser.getCurrentLocation(); 588 StructType::OffsetInfo offset = 0; 589 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); 590 if (offsetParseResult.has_value()) { 591 if (failed(*offsetParseResult)) 592 return failure(); 593 594 if (offsetInfo.size() != memberTypes.size() - 1) { 595 return parser.emitError(offsetLoc, 596 "offset specification must be given for " 597 "all members"); 598 } 599 offsetInfo.push_back(offset); 600 } 601 602 // Check for no spirv::Decorations. 603 if (succeeded(parser.parseOptionalRSquare())) 604 return success(); 605 606 // If there was an offset, make sure to parse the comma. 
607 if (offsetParseResult.has_value() && parser.parseComma()) 608 return failure(); 609 610 // Check for spirv::Decorations. 611 auto parseDecorations = [&]() { 612 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); 613 if (!memberDecoration) 614 return failure(); 615 616 // Parse member decoration value if it exists. 617 if (succeeded(parser.parseOptionalEqual())) { 618 auto memberDecorationValue = 619 parseAndVerifyInteger<uint32_t>(dialect, parser); 620 621 if (!memberDecorationValue) 622 return failure(); 623 624 memberDecorationInfo.emplace_back( 625 static_cast<uint32_t>(memberTypes.size() - 1), 1, 626 memberDecoration.value(), memberDecorationValue.value()); 627 } else { 628 memberDecorationInfo.emplace_back( 629 static_cast<uint32_t>(memberTypes.size() - 1), 0, 630 memberDecoration.value(), 0); 631 } 632 return success(); 633 }; 634 if (failed(parser.parseCommaSeparatedList(parseDecorations)) || 635 failed(parser.parseRSquare())) 636 return failure(); 637 638 return success(); 639 } 640 641 // struct-member-decoration ::= integer-literal? spirv-decoration* 642 // struct-type ::= 643 // `!spirv.struct<` (id `,`)? 644 // `(` 645 // (spirv-type (`[` struct-member-decoration `]`)?)* 646 // `)>` 647 static Type parseStructType(SPIRVDialect const &dialect, 648 DialectAsmParser &parser) { 649 // TODO: This function is quite lengthy. Break it down into smaller chunks. 650 651 if (parser.parseLess()) 652 return Type(); 653 654 StringRef identifier; 655 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse; 656 657 // Check if this is an identified struct type. 658 if (succeeded(parser.parseOptionalKeyword(&identifier))) { 659 // Check if this is a possible recursive reference. 
660 auto structType = 661 StructType::getIdentified(dialect.getContext(), identifier); 662 cyclicParse = parser.tryStartCyclicParse(structType); 663 if (succeeded(parser.parseOptionalGreater())) { 664 if (succeeded(cyclicParse)) { 665 parser.emitError( 666 parser.getNameLoc(), 667 "recursive struct reference not nested in struct definition"); 668 669 return Type(); 670 } 671 672 return structType; 673 } 674 675 if (failed(parser.parseComma())) 676 return Type(); 677 678 if (failed(cyclicParse)) { 679 parser.emitError(parser.getNameLoc(), 680 "identifier already used for an enclosing struct"); 681 return Type(); 682 } 683 } 684 685 if (failed(parser.parseLParen())) 686 return Type(); 687 688 if (succeeded(parser.parseOptionalRParen()) && 689 succeeded(parser.parseOptionalGreater())) { 690 return StructType::getEmpty(dialect.getContext(), identifier); 691 } 692 693 StructType idStructTy; 694 695 if (!identifier.empty()) 696 idStructTy = StructType::getIdentified(dialect.getContext(), identifier); 697 698 SmallVector<Type, 4> memberTypes; 699 SmallVector<StructType::OffsetInfo, 4> offsetInfo; 700 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; 701 702 do { 703 Type memberType; 704 if (parser.parseType(memberType)) 705 return Type(); 706 memberTypes.push_back(memberType); 707 708 if (succeeded(parser.parseOptionalLSquare())) 709 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, 710 memberDecorationInfo)) 711 return Type(); 712 } while (succeeded(parser.parseOptionalComma())); 713 714 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { 715 parser.emitError(parser.getNameLoc(), 716 "offset specification must be given for all members"); 717 return Type(); 718 } 719 720 if (failed(parser.parseRParen()) || failed(parser.parseGreater())) 721 return Type(); 722 723 if (!identifier.empty()) { 724 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, 725 memberDecorationInfo))) 726 return Type(); 727 return 
idStructTy; 728 } 729 730 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); 731 } 732 733 // spirv-type ::= array-type 734 // | element-type 735 // | image-type 736 // | pointer-type 737 // | runtime-array-type 738 // | sampled-image-type 739 // | struct-type 740 Type SPIRVDialect::parseType(DialectAsmParser &parser) const { 741 StringRef keyword; 742 if (parser.parseKeyword(&keyword)) 743 return Type(); 744 745 if (keyword == "array") 746 return parseArrayType(*this, parser); 747 if (keyword == "coopmatrix") 748 return parseCooperativeMatrixType(*this, parser); 749 if (keyword == "image") 750 return parseImageType(*this, parser); 751 if (keyword == "ptr") 752 return parsePointerType(*this, parser); 753 if (keyword == "rtarray") 754 return parseRuntimeArrayType(*this, parser); 755 if (keyword == "sampled_image") 756 return parseSampledImageType(*this, parser); 757 if (keyword == "struct") 758 return parseStructType(*this, parser); 759 if (keyword == "matrix") 760 return parseMatrixType(*this, parser); 761 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; 762 return Type(); 763 } 764 765 //===----------------------------------------------------------------------===// 766 // Type Printing 767 //===----------------------------------------------------------------------===// 768 769 static void print(ArrayType type, DialectAsmPrinter &os) { 770 os << "array<" << type.getNumElements() << " x " << type.getElementType(); 771 if (unsigned stride = type.getArrayStride()) 772 os << ", stride=" << stride; 773 os << ">"; 774 } 775 776 static void print(RuntimeArrayType type, DialectAsmPrinter &os) { 777 os << "rtarray<" << type.getElementType(); 778 if (unsigned stride = type.getArrayStride()) 779 os << ", stride=" << stride; 780 os << ">"; 781 } 782 783 static void print(PointerType type, DialectAsmPrinter &os) { 784 os << "ptr<" << type.getPointeeType() << ", " 785 << stringifyStorageClass(type.getStorageClass()) << ">"; 786 } 787 
788 static void print(ImageType type, DialectAsmPrinter &os) { 789 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) 790 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " 791 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " 792 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " 793 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " 794 << stringifyImageFormat(type.getImageFormat()) << ">"; 795 } 796 797 static void print(SampledImageType type, DialectAsmPrinter &os) { 798 os << "sampled_image<" << type.getImageType() << ">"; 799 } 800 801 static void print(StructType type, DialectAsmPrinter &os) { 802 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint; 803 804 os << "struct<"; 805 806 if (type.isIdentified()) { 807 os << type.getIdentifier(); 808 809 cyclicPrint = os.tryStartCyclicPrint(type); 810 if (failed(cyclicPrint)) { 811 os << ">"; 812 return; 813 } 814 815 os << ", "; 816 } 817 818 os << "("; 819 820 auto printMember = [&](unsigned i) { 821 os << type.getElementType(i); 822 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; 823 type.getMemberDecorations(i, decorations); 824 if (type.hasOffset() || !decorations.empty()) { 825 os << " ["; 826 if (type.hasOffset()) { 827 os << type.getMemberOffset(i); 828 if (!decorations.empty()) 829 os << ", "; 830 } 831 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { 832 os << stringifyDecoration(decoration.decoration); 833 if (decoration.hasValue) { 834 os << "=" << decoration.decorationValue; 835 } 836 }; 837 llvm::interleaveComma(decorations, os, eachFn); 838 os << "]"; 839 } 840 }; 841 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, 842 printMember); 843 os << ")>"; 844 } 845 846 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { 847 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x" 848 << type.getElementType() << ", " << 
type.getScope() << ", " 849 << type.getUse() << ">"; 850 } 851 852 static void print(MatrixType type, DialectAsmPrinter &os) { 853 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); 854 os << ">"; 855 } 856 857 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { 858 TypeSwitch<Type>(type) 859 .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType, 860 ImageType, SampledImageType, StructType, MatrixType>( 861 [&](auto type) { print(type, os); }) 862 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); 863 } 864 865 //===----------------------------------------------------------------------===// 866 // Constant 867 //===----------------------------------------------------------------------===// 868 869 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, 870 Attribute value, Type type, 871 Location loc) { 872 if (auto poison = dyn_cast<ub::PoisonAttr>(value)) 873 return builder.create<ub::PoisonOp>(loc, type, poison); 874 875 if (!spirv::ConstantOp::isBuildableWith(type)) 876 return nullptr; 877 878 return builder.create<spirv::ConstantOp>(loc, type, value); 879 } 880 881 //===----------------------------------------------------------------------===// 882 // Shader Interface ABI 883 //===----------------------------------------------------------------------===// 884 885 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, 886 NamedAttribute attribute) { 887 StringRef symbol = attribute.getName().strref(); 888 Attribute attr = attribute.getValue(); 889 890 if (symbol == spirv::getEntryPointABIAttrName()) { 891 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) { 892 return op->emitError("'") 893 << symbol << "' attribute must be an entry point ABI attribute"; 894 } 895 } else if (symbol == spirv::getTargetEnvAttrName()) { 896 if (!llvm::isa<spirv::TargetEnvAttr>(attr)) 897 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; 898 } else { 899 return 
op->emitError("found unsupported '") 900 << symbol << "' attribute on operation"; 901 } 902 903 return success(); 904 } 905 906 /// Verifies the given SPIR-V `attribute` attached to a value of the given 907 /// `valueType` is valid. 908 static LogicalResult verifyRegionAttribute(Location loc, Type valueType, 909 NamedAttribute attribute) { 910 StringRef symbol = attribute.getName().strref(); 911 Attribute attr = attribute.getValue(); 912 913 if (symbol == spirv::getInterfaceVarABIAttrName()) { 914 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr); 915 if (!varABIAttr) 916 return emitError(loc, "'") 917 << symbol << "' must be a spirv::InterfaceVarABIAttr"; 918 919 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) 920 return emitError(loc, "'") << symbol 921 << "' attribute cannot specify storage class " 922 "when attaching to a non-scalar value"; 923 return success(); 924 } 925 if (symbol == spirv::DecorationAttr::name) { 926 if (!isa<spirv::DecorationAttr>(attr)) 927 return emitError(loc, "'") 928 << symbol << "' must be a spirv::DecorationAttr"; 929 return success(); 930 } 931 932 return emitError(loc, "found unsupported '") 933 << symbol << "' attribute on region argument"; 934 } 935 936 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, 937 unsigned regionIndex, 938 unsigned argIndex, 939 NamedAttribute attribute) { 940 auto funcOp = dyn_cast<FunctionOpInterface>(op); 941 if (!funcOp) 942 return success(); 943 Type argType = funcOp.getArgumentTypes()[argIndex]; 944 945 return verifyRegionAttribute(op->getLoc(), argType, attribute); 946 } 947 948 LogicalResult SPIRVDialect::verifyRegionResultAttribute( 949 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, 950 NamedAttribute attribute) { 951 return op->emitError("cannot attach SPIR-V attributes to region result"); 952 } 953