xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
201178654SLei Zhang //
301178654SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM
401178654SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
501178654SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
601178654SLei Zhang //
701178654SLei Zhang //===----------------------------------------------------------------------===//
801178654SLei Zhang //
901178654SLei Zhang // This file defines the SPIR-V dialect in MLIR.
1001178654SLei Zhang //
1101178654SLei Zhang //===----------------------------------------------------------------------===//
1201178654SLei Zhang 
1301178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
149415241cSJakub Kuderski 
159415241cSJakub Kuderski #include "SPIRVParsingUtils.h"
169415241cSJakub Kuderski 
172dace045SSang Ik Lee #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1901178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
2001178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
215dce7481SIvan Butygin #include "mlir/Dialect/UB/IR/UBOps.h"
2201178654SLei Zhang #include "mlir/IR/Builders.h"
2301178654SLei Zhang #include "mlir/IR/BuiltinTypes.h"
2401178654SLei Zhang #include "mlir/IR/DialectImplementation.h"
2501178654SLei Zhang #include "mlir/IR/MLIRContext.h"
269eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
2701178654SLei Zhang #include "mlir/Transforms/InliningUtils.h"
2801178654SLei Zhang #include "llvm/ADT/DenseMap.h"
2901178654SLei Zhang #include "llvm/ADT/Sequence.h"
3001178654SLei Zhang #include "llvm/ADT/SetVector.h"
3101178654SLei Zhang #include "llvm/ADT/StringExtras.h"
3201178654SLei Zhang #include "llvm/ADT/StringMap.h"
3301178654SLei Zhang #include "llvm/ADT/TypeSwitch.h"
3401178654SLei Zhang #include "llvm/Support/raw_ostream.h"
3501178654SLei Zhang 
3601178654SLei Zhang using namespace mlir;
3701178654SLei Zhang using namespace mlir::spirv;
3801178654SLei Zhang 
39485cc55eSStella Laurenzo #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40485cc55eSStella Laurenzo 
4101178654SLei Zhang //===----------------------------------------------------------------------===//
4201178654SLei Zhang // InlinerInterface
4301178654SLei Zhang //===----------------------------------------------------------------------===//
4401178654SLei Zhang 
455ab6ef75SJakub Kuderski /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
465ab6ef75SJakub Kuderski /// ops.
4701178654SLei Zhang static inline bool containsReturn(Region &region) {
4801178654SLei Zhang   return llvm::any_of(region, [](Block &block) {
4901178654SLei Zhang     Operation *terminator = block.getTerminator();
5001178654SLei Zhang     return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
5101178654SLei Zhang   });
5201178654SLei Zhang }
5301178654SLei Zhang 
5401178654SLei Zhang namespace {
5501178654SLei Zhang /// This class defines the interface for inlining within the SPIR-V dialect.
5601178654SLei Zhang struct SPIRVInlinerInterface : public DialectInlinerInterface {
5701178654SLei Zhang   using DialectInlinerInterface::DialectInlinerInterface;
5801178654SLei Zhang 
5901178654SLei Zhang   /// All call operations within SPIRV can be inlined.
6001178654SLei Zhang   bool isLegalToInline(Operation *call, Operation *callable,
6101178654SLei Zhang                        bool wouldBeCloned) const final {
6201178654SLei Zhang     return true;
6301178654SLei Zhang   }
6401178654SLei Zhang 
6501178654SLei Zhang   /// Returns true if the given region 'src' can be inlined into the region
6601178654SLei Zhang   /// 'dest' that is attached to an operation registered to the current dialect.
6701178654SLei Zhang   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
684d67b278SJeff Niu                        IRMapping &) const final {
695ab6ef75SJakub Kuderski     // Return true here when inlining into spirv.func, spirv.mlir.selection, and
705ab6ef75SJakub Kuderski     // spirv.mlir.loop operations.
7101178654SLei Zhang     auto *op = dest->getParentOp();
7201178654SLei Zhang     return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
7301178654SLei Zhang   }
7401178654SLei Zhang 
7501178654SLei Zhang   /// Returns true if the given operation 'op', that is registered to this
7601178654SLei Zhang   /// dialect, can be inlined into the region 'dest' that is attached to an
7701178654SLei Zhang   /// operation registered to the current dialect.
7801178654SLei Zhang   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
794d67b278SJeff Niu                        IRMapping &) const final {
8001178654SLei Zhang     // TODO: Enable inlining structured control flows with return.
8101178654SLei Zhang     if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
8201178654SLei Zhang         containsReturn(op->getRegion(0)))
8301178654SLei Zhang       return false;
8401178654SLei Zhang     // TODO: we need to filter OpKill here to avoid inlining it to
8501178654SLei Zhang     // a loop continue construct:
8601178654SLei Zhang     // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
8701178654SLei Zhang     // However OpKill is fragment shader specific and we don't support it yet.
8801178654SLei Zhang     return true;
8901178654SLei Zhang   }
9001178654SLei Zhang 
9101178654SLei Zhang   /// Handle the given inlined terminator by replacing it with a new operation
9201178654SLei Zhang   /// as necessary.
9301178654SLei Zhang   void handleTerminator(Operation *op, Block *newDest) const final {
9401178654SLei Zhang     if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
9501178654SLei Zhang       OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
9601178654SLei Zhang       op->erase();
9701178654SLei Zhang     } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
989a5fb74fSArtem Tyurin       OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
999a5fb74fSArtem Tyurin                                             retValOp->getOperands());
1009a5fb74fSArtem Tyurin       op->erase();
10101178654SLei Zhang     }
10201178654SLei Zhang   }
10301178654SLei Zhang 
10401178654SLei Zhang   /// Handle the given inlined terminator by replacing it with a new operation
10501178654SLei Zhang   /// as necessary.
10626a0b277SMehdi Amini   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
1075ab6ef75SJakub Kuderski     // Only spirv.ReturnValue needs to be handled here.
10801178654SLei Zhang     auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
10901178654SLei Zhang     if (!retValOp)
11001178654SLei Zhang       return;
11101178654SLei Zhang 
11201178654SLei Zhang     // Replace the values directly with the return operands.
11301178654SLei Zhang     assert(valuesToRepl.size() == 1 &&
1145ab6ef75SJakub Kuderski            "spirv.ReturnValue expected to only handle one result");
11590a1632dSJakub Kuderski     valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
11601178654SLei Zhang   }
11701178654SLei Zhang };
11801178654SLei Zhang } // namespace
11901178654SLei Zhang 
12001178654SLei Zhang //===----------------------------------------------------------------------===//
12101178654SLei Zhang // SPIR-V Dialect
12201178654SLei Zhang //===----------------------------------------------------------------------===//
12301178654SLei Zhang 
12401178654SLei Zhang void SPIRVDialect::initialize() {
12531bb8efdSRiver Riddle   registerAttributes();
12631bb8efdSRiver Riddle   registerTypes();
12701178654SLei Zhang 
12801178654SLei Zhang   // Add SPIR-V ops.
12901178654SLei Zhang   addOperations<
13001178654SLei Zhang #define GET_OP_LIST
13101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
13201178654SLei Zhang       >();
13301178654SLei Zhang 
13401178654SLei Zhang   addInterfaces<SPIRVInlinerInterface>();
13501178654SLei Zhang 
13601178654SLei Zhang   // Allow unknown operations because SPIR-V is extensible.
13701178654SLei Zhang   allowUnknownOperations();
13835d55f28SJustin Fargnoli   declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
13901178654SLei Zhang }
14001178654SLei Zhang 
14101178654SLei Zhang std::string SPIRVDialect::getAttributeName(Decoration decoration) {
14201178654SLei Zhang   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
14301178654SLei Zhang }
14401178654SLei Zhang 
14501178654SLei Zhang //===----------------------------------------------------------------------===//
14601178654SLei Zhang // Type Parsing
14701178654SLei Zhang //===----------------------------------------------------------------------===//
14801178654SLei Zhang 
14901178654SLei Zhang // Forward declarations.
15001178654SLei Zhang template <typename ValTy>
15122426110SRamkumar Ramachandra static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
15201178654SLei Zhang                                            DialectAsmParser &parser);
15301178654SLei Zhang template <>
15422426110SRamkumar Ramachandra std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
15501178654SLei Zhang                                          DialectAsmParser &parser);
15601178654SLei Zhang 
15701178654SLei Zhang template <>
15822426110SRamkumar Ramachandra std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
15901178654SLei Zhang                                                  DialectAsmParser &parser);
16001178654SLei Zhang 
16101178654SLei Zhang static Type parseAndVerifyType(SPIRVDialect const &dialect,
16201178654SLei Zhang                                DialectAsmParser &parser) {
16301178654SLei Zhang   Type type;
1646842ec42SRiver Riddle   SMLoc typeLoc = parser.getCurrentLocation();
16501178654SLei Zhang   if (parser.parseType(type))
16601178654SLei Zhang     return Type();
16701178654SLei Zhang 
16801178654SLei Zhang   // Allow SPIR-V dialect types
16901178654SLei Zhang   if (&type.getDialect() == &dialect)
17001178654SLei Zhang     return type;
17101178654SLei Zhang 
17201178654SLei Zhang   // Check other allowed types
173c1fa60b4STres Popp   if (auto t = llvm::dyn_cast<FloatType>(type)) {
17401178654SLei Zhang     if (type.isBF16()) {
17501178654SLei Zhang       parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
17601178654SLei Zhang       return Type();
17701178654SLei Zhang     }
178c1fa60b4STres Popp   } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
17901178654SLei Zhang     if (!ScalarType::isValid(t)) {
18001178654SLei Zhang       parser.emitError(typeLoc,
18101178654SLei Zhang                        "only 1/8/16/32/64-bit integer type allowed but found ")
18201178654SLei Zhang           << type;
18301178654SLei Zhang       return Type();
18401178654SLei Zhang     }
185c1fa60b4STres Popp   } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
18601178654SLei Zhang     if (t.getRank() != 1) {
18701178654SLei Zhang       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
18801178654SLei Zhang       return Type();
18901178654SLei Zhang     }
19001178654SLei Zhang     if (t.getNumElements() > 4) {
19101178654SLei Zhang       parser.emitError(
19201178654SLei Zhang           typeLoc, "vector length has to be less than or equal to 4 but found ")
19301178654SLei Zhang           << t.getNumElements();
19401178654SLei Zhang       return Type();
19501178654SLei Zhang     }
19601178654SLei Zhang   } else {
19701178654SLei Zhang     parser.emitError(typeLoc, "cannot use ")
19801178654SLei Zhang         << type << " to compose SPIR-V types";
19901178654SLei Zhang     return Type();
20001178654SLei Zhang   }
20101178654SLei Zhang 
20201178654SLei Zhang   return type;
20301178654SLei Zhang }
20401178654SLei Zhang 
20501178654SLei Zhang static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
20601178654SLei Zhang                                      DialectAsmParser &parser) {
20701178654SLei Zhang   Type type;
2086842ec42SRiver Riddle   SMLoc typeLoc = parser.getCurrentLocation();
20901178654SLei Zhang   if (parser.parseType(type))
21001178654SLei Zhang     return Type();
21101178654SLei Zhang 
212c1fa60b4STres Popp   if (auto t = llvm::dyn_cast<VectorType>(type)) {
21301178654SLei Zhang     if (t.getRank() != 1) {
21401178654SLei Zhang       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
21501178654SLei Zhang       return Type();
21601178654SLei Zhang     }
21701178654SLei Zhang     if (t.getNumElements() > 4 || t.getNumElements() < 2) {
21801178654SLei Zhang       parser.emitError(typeLoc,
21901178654SLei Zhang                        "matrix columns size has to be less than or equal "
22001178654SLei Zhang                        "to 4 and greater than or equal 2, but found ")
22101178654SLei Zhang           << t.getNumElements();
22201178654SLei Zhang       return Type();
22301178654SLei Zhang     }
22401178654SLei Zhang 
225c1fa60b4STres Popp     if (!llvm::isa<FloatType>(t.getElementType())) {
22601178654SLei Zhang       parser.emitError(typeLoc, "matrix columns' elements must be of "
22701178654SLei Zhang                                 "Float type, got ")
22801178654SLei Zhang           << t.getElementType();
22901178654SLei Zhang       return Type();
23001178654SLei Zhang     }
23101178654SLei Zhang   } else {
23201178654SLei Zhang     parser.emitError(typeLoc, "matrix must be composed using vector "
23301178654SLei Zhang                               "type, got ")
23401178654SLei Zhang         << type;
23501178654SLei Zhang     return Type();
23601178654SLei Zhang   }
23701178654SLei Zhang 
23801178654SLei Zhang   return type;
23901178654SLei Zhang }
24001178654SLei Zhang 
2412ef24139SWeiwei Li static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
2422ef24139SWeiwei Li                                            DialectAsmParser &parser) {
2432ef24139SWeiwei Li   Type type;
2446842ec42SRiver Riddle   SMLoc typeLoc = parser.getCurrentLocation();
2452ef24139SWeiwei Li   if (parser.parseType(type))
2462ef24139SWeiwei Li     return Type();
2472ef24139SWeiwei Li 
248c1fa60b4STres Popp   if (!llvm::isa<ImageType>(type)) {
2492ef24139SWeiwei Li     parser.emitError(typeLoc,
2502ef24139SWeiwei Li                      "sampled image must be composed using image type, got ")
2512ef24139SWeiwei Li         << type;
2522ef24139SWeiwei Li     return Type();
2532ef24139SWeiwei Li   }
2542ef24139SWeiwei Li 
2552ef24139SWeiwei Li   return type;
2562ef24139SWeiwei Li }
2572ef24139SWeiwei Li 
25801178654SLei Zhang /// Parses an optional `, stride = N` assembly segment. If no parsing failure
25901178654SLei Zhang /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
26001178654SLei Zhang /// missing.
26101178654SLei Zhang static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
26201178654SLei Zhang                                               DialectAsmParser &parser,
26301178654SLei Zhang                                               unsigned &stride) {
26401178654SLei Zhang   if (failed(parser.parseOptionalComma())) {
26501178654SLei Zhang     stride = 0;
26601178654SLei Zhang     return success();
26701178654SLei Zhang   }
26801178654SLei Zhang 
26901178654SLei Zhang   if (parser.parseKeyword("stride") || parser.parseEqual())
27001178654SLei Zhang     return failure();
27101178654SLei Zhang 
2726842ec42SRiver Riddle   SMLoc strideLoc = parser.getCurrentLocation();
27322426110SRamkumar Ramachandra   std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
27401178654SLei Zhang   if (!optStride)
27501178654SLei Zhang     return failure();
27601178654SLei Zhang 
2776d5fc1e3SKazu Hirata   if (!(stride = *optStride)) {
27801178654SLei Zhang     parser.emitError(strideLoc, "ArrayStride must be greater than zero");
27901178654SLei Zhang     return failure();
28001178654SLei Zhang   }
28101178654SLei Zhang   return success();
28201178654SLei Zhang }
28301178654SLei Zhang 
28401178654SLei Zhang // element-type ::= integer-type
28501178654SLei Zhang //                | floating-point-type
28601178654SLei Zhang //                | vector-type
28701178654SLei Zhang //                | spirv-type
28801178654SLei Zhang //
2895ab6ef75SJakub Kuderski // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
29001178654SLei Zhang //                (`,` `stride` `=` integer-literal)? `>`
29101178654SLei Zhang static Type parseArrayType(SPIRVDialect const &dialect,
29201178654SLei Zhang                            DialectAsmParser &parser) {
29301178654SLei Zhang   if (parser.parseLess())
29401178654SLei Zhang     return Type();
29501178654SLei Zhang 
29601178654SLei Zhang   SmallVector<int64_t, 1> countDims;
2976842ec42SRiver Riddle   SMLoc countLoc = parser.getCurrentLocation();
29801178654SLei Zhang   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
29901178654SLei Zhang     return Type();
30001178654SLei Zhang   if (countDims.size() != 1) {
30101178654SLei Zhang     parser.emitError(countLoc,
30201178654SLei Zhang                      "expected single integer for array element count");
30301178654SLei Zhang     return Type();
30401178654SLei Zhang   }
30501178654SLei Zhang 
30601178654SLei Zhang   // According to the SPIR-V spec:
30701178654SLei Zhang   // "Length is the number of elements in the array. It must be at least 1."
30801178654SLei Zhang   int64_t count = countDims[0];
30901178654SLei Zhang   if (count == 0) {
31001178654SLei Zhang     parser.emitError(countLoc, "expected array length greater than 0");
31101178654SLei Zhang     return Type();
31201178654SLei Zhang   }
31301178654SLei Zhang 
31401178654SLei Zhang   Type elementType = parseAndVerifyType(dialect, parser);
31501178654SLei Zhang   if (!elementType)
31601178654SLei Zhang     return Type();
31701178654SLei Zhang 
31801178654SLei Zhang   unsigned stride = 0;
31901178654SLei Zhang   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
32001178654SLei Zhang     return Type();
32101178654SLei Zhang 
32201178654SLei Zhang   if (parser.parseGreater())
32301178654SLei Zhang     return Type();
32401178654SLei Zhang   return ArrayType::get(elementType, count, stride);
32501178654SLei Zhang }
32601178654SLei Zhang 
3271d515978SJakub Kuderski // cooperative-matrix-type ::=
3284ba61f5aSJakub Kuderski //   `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
3294ba61f5aSJakub Kuderski //                           scope `,` use `>`
33001178654SLei Zhang static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
33101178654SLei Zhang                                        DialectAsmParser &parser) {
33201178654SLei Zhang   if (parser.parseLess())
3334ba61f5aSJakub Kuderski     return {};
3344ba61f5aSJakub Kuderski 
3354ba61f5aSJakub Kuderski   SmallVector<int64_t, 2> dims;
3364ba61f5aSJakub Kuderski   SMLoc countLoc = parser.getCurrentLocation();
3374ba61f5aSJakub Kuderski   if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
3384ba61f5aSJakub Kuderski     return {};
3394ba61f5aSJakub Kuderski 
3404ba61f5aSJakub Kuderski   if (dims.size() != 2) {
3414ba61f5aSJakub Kuderski     parser.emitError(countLoc, "expected row and column count");
3424ba61f5aSJakub Kuderski     return {};
3434ba61f5aSJakub Kuderski   }
3444ba61f5aSJakub Kuderski 
3454ba61f5aSJakub Kuderski   auto elementTy = parseAndVerifyType(dialect, parser);
3464ba61f5aSJakub Kuderski   if (!elementTy)
3474ba61f5aSJakub Kuderski     return {};
3484ba61f5aSJakub Kuderski 
3494ba61f5aSJakub Kuderski   Scope scope;
3509415241cSJakub Kuderski   if (parser.parseComma() ||
3519415241cSJakub Kuderski       spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
3524ba61f5aSJakub Kuderski     return {};
3534ba61f5aSJakub Kuderski 
3544ba61f5aSJakub Kuderski   CooperativeMatrixUseKHR use;
3559415241cSJakub Kuderski   if (parser.parseComma() ||
3569415241cSJakub Kuderski       spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
3574ba61f5aSJakub Kuderski     return {};
3584ba61f5aSJakub Kuderski 
3594ba61f5aSJakub Kuderski   if (parser.parseGreater())
3604ba61f5aSJakub Kuderski     return {};
3614ba61f5aSJakub Kuderski 
3624ba61f5aSJakub Kuderski   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
3634ba61f5aSJakub Kuderski }
3644ba61f5aSJakub Kuderski 
36501178654SLei Zhang // TODO: Reorder methods to be utilities first and parse*Type
36601178654SLei Zhang // methods in alphabetical order
36701178654SLei Zhang //
36801178654SLei Zhang // storage-class ::= `UniformConstant`
36901178654SLei Zhang //                 | `Uniform`
37001178654SLei Zhang //                 | `Workgroup`
37101178654SLei Zhang //                 | <and other storage classes...>
37201178654SLei Zhang //
3735ab6ef75SJakub Kuderski // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
37401178654SLei Zhang static Type parsePointerType(SPIRVDialect const &dialect,
37501178654SLei Zhang                              DialectAsmParser &parser) {
37601178654SLei Zhang   if (parser.parseLess())
37701178654SLei Zhang     return Type();
37801178654SLei Zhang 
37901178654SLei Zhang   auto pointeeType = parseAndVerifyType(dialect, parser);
38001178654SLei Zhang   if (!pointeeType)
38101178654SLei Zhang     return Type();
38201178654SLei Zhang 
38301178654SLei Zhang   StringRef storageClassSpec;
3846842ec42SRiver Riddle   SMLoc storageClassLoc = parser.getCurrentLocation();
38501178654SLei Zhang   if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
38601178654SLei Zhang     return Type();
38701178654SLei Zhang 
38801178654SLei Zhang   auto storageClass = symbolizeStorageClass(storageClassSpec);
38901178654SLei Zhang   if (!storageClass) {
39001178654SLei Zhang     parser.emitError(storageClassLoc, "unknown storage class: ")
39101178654SLei Zhang         << storageClassSpec;
39201178654SLei Zhang     return Type();
39301178654SLei Zhang   }
39401178654SLei Zhang   if (parser.parseGreater())
39501178654SLei Zhang     return Type();
39601178654SLei Zhang   return PointerType::get(pointeeType, *storageClass);
39701178654SLei Zhang }
39801178654SLei Zhang 
3995ab6ef75SJakub Kuderski // runtime-array-type ::= `!spirv.rtarray` `<` element-type
40001178654SLei Zhang //                        (`,` `stride` `=` integer-literal)? `>`
40101178654SLei Zhang static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
40201178654SLei Zhang                                   DialectAsmParser &parser) {
40301178654SLei Zhang   if (parser.parseLess())
40401178654SLei Zhang     return Type();
40501178654SLei Zhang 
40601178654SLei Zhang   Type elementType = parseAndVerifyType(dialect, parser);
40701178654SLei Zhang   if (!elementType)
40801178654SLei Zhang     return Type();
40901178654SLei Zhang 
41001178654SLei Zhang   unsigned stride = 0;
41101178654SLei Zhang   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
41201178654SLei Zhang     return Type();
41301178654SLei Zhang 
41401178654SLei Zhang   if (parser.parseGreater())
41501178654SLei Zhang     return Type();
41601178654SLei Zhang   return RuntimeArrayType::get(elementType, stride);
41701178654SLei Zhang }
41801178654SLei Zhang 
4195ab6ef75SJakub Kuderski // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
42001178654SLei Zhang static Type parseMatrixType(SPIRVDialect const &dialect,
42101178654SLei Zhang                             DialectAsmParser &parser) {
42201178654SLei Zhang   if (parser.parseLess())
42301178654SLei Zhang     return Type();
42401178654SLei Zhang 
42501178654SLei Zhang   SmallVector<int64_t, 1> countDims;
4266842ec42SRiver Riddle   SMLoc countLoc = parser.getCurrentLocation();
42701178654SLei Zhang   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
42801178654SLei Zhang     return Type();
42901178654SLei Zhang   if (countDims.size() != 1) {
43001178654SLei Zhang     parser.emitError(countLoc, "expected single unsigned "
43101178654SLei Zhang                                "integer for number of columns");
43201178654SLei Zhang     return Type();
43301178654SLei Zhang   }
43401178654SLei Zhang 
43501178654SLei Zhang   int64_t columnCount = countDims[0];
43601178654SLei Zhang   // According to the specification, Matrices can have 2, 3, or 4 columns
43701178654SLei Zhang   if (columnCount < 2 || columnCount > 4) {
43801178654SLei Zhang     parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
43901178654SLei Zhang                                "columns");
44001178654SLei Zhang     return Type();
44101178654SLei Zhang   }
44201178654SLei Zhang 
44301178654SLei Zhang   Type columnType = parseAndVerifyMatrixType(dialect, parser);
44401178654SLei Zhang   if (!columnType)
44501178654SLei Zhang     return Type();
44601178654SLei Zhang 
44701178654SLei Zhang   if (parser.parseGreater())
44801178654SLei Zhang     return Type();
44901178654SLei Zhang 
45001178654SLei Zhang   return MatrixType::get(columnType, columnCount);
45101178654SLei Zhang }
45201178654SLei Zhang 
45301178654SLei Zhang // Specialize this function to parse each of the parameters that define an
45401178654SLei Zhang // ImageType. By default it assumes this is an enum type.
45501178654SLei Zhang template <typename ValTy>
45622426110SRamkumar Ramachandra static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
45701178654SLei Zhang                                            DialectAsmParser &parser) {
45801178654SLei Zhang   StringRef enumSpec;
4596842ec42SRiver Riddle   SMLoc enumLoc = parser.getCurrentLocation();
46001178654SLei Zhang   if (parser.parseKeyword(&enumSpec)) {
4611a36588eSKazu Hirata     return std::nullopt;
46201178654SLei Zhang   }
46301178654SLei Zhang 
46401178654SLei Zhang   auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
46501178654SLei Zhang   if (!val)
46601178654SLei Zhang     parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
46701178654SLei Zhang   return val;
46801178654SLei Zhang }
46901178654SLei Zhang 
47001178654SLei Zhang template <>
47122426110SRamkumar Ramachandra std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
47201178654SLei Zhang                                          DialectAsmParser &parser) {
47301178654SLei Zhang   // TODO: Further verify that the element type can be sampled
47401178654SLei Zhang   auto ty = parseAndVerifyType(dialect, parser);
47501178654SLei Zhang   if (!ty)
4761a36588eSKazu Hirata     return std::nullopt;
47701178654SLei Zhang   return ty;
47801178654SLei Zhang }
47901178654SLei Zhang 
48001178654SLei Zhang template <typename IntTy>
48122426110SRamkumar Ramachandra static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
48201178654SLei Zhang                                                   DialectAsmParser &parser) {
48301178654SLei Zhang   IntTy offsetVal = std::numeric_limits<IntTy>::max();
48401178654SLei Zhang   if (parser.parseInteger(offsetVal))
4851a36588eSKazu Hirata     return std::nullopt;
48601178654SLei Zhang   return offsetVal;
48701178654SLei Zhang }
48801178654SLei Zhang 
48901178654SLei Zhang template <>
49022426110SRamkumar Ramachandra std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
49101178654SLei Zhang                                                  DialectAsmParser &parser) {
49201178654SLei Zhang   return parseAndVerifyInteger<unsigned>(dialect, parser);
49301178654SLei Zhang }
49401178654SLei Zhang 
49501178654SLei Zhang namespace {
49601178654SLei Zhang // Functor object to parse a comma separated list of specs. The function
49701178654SLei Zhang // parseAndVerify does the actual parsing and verification of individual
49801178654SLei Zhang // elements. This is a functor since parsing the last element of the list
49901178654SLei Zhang // (termination condition) needs partial specialization.
500b7f93c28SJeff Niu template <typename ParseType, typename... Args>
501b7f93c28SJeff Niu struct ParseCommaSeparatedList {
50222426110SRamkumar Ramachandra   std::optional<std::tuple<ParseType, Args...>>
50301178654SLei Zhang   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
50401178654SLei Zhang     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
50501178654SLei Zhang     if (!parseVal)
5061a36588eSKazu Hirata       return std::nullopt;
50701178654SLei Zhang 
50801178654SLei Zhang     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
50901178654SLei Zhang     if (numArgs != 0 && failed(parser.parseComma()))
5101a36588eSKazu Hirata       return std::nullopt;
51101178654SLei Zhang     auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
51201178654SLei Zhang     if (!remainingValues)
5131a36588eSKazu Hirata       return std::nullopt;
514c27d8152SKazu Hirata     return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
515c27d8152SKazu Hirata                           remainingValues.value());
51601178654SLei Zhang   }
51701178654SLei Zhang };
51801178654SLei Zhang 
51901178654SLei Zhang // Partial specialization of the function to parse a comma separated list of
52001178654SLei Zhang // specs to parse the last element of the list.
521b7f93c28SJeff Niu template <typename ParseType>
522b7f93c28SJeff Niu struct ParseCommaSeparatedList<ParseType> {
52322426110SRamkumar Ramachandra   std::optional<std::tuple<ParseType>>
52422426110SRamkumar Ramachandra   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
52501178654SLei Zhang     if (auto value = parseAndVerify<ParseType>(dialect, parser))
5266d5fc1e3SKazu Hirata       return std::tuple<ParseType>(*value);
5271a36588eSKazu Hirata     return std::nullopt;
52801178654SLei Zhang   }
52901178654SLei Zhang };
53001178654SLei Zhang } // namespace
53101178654SLei Zhang 
53201178654SLei Zhang // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
53301178654SLei Zhang //
53401178654SLei Zhang // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
53501178654SLei Zhang //
53601178654SLei Zhang // arrayed-info ::= `NonArrayed` | `Arrayed`
53701178654SLei Zhang //
53801178654SLei Zhang // sampling-info ::= `SingleSampled` | `MultiSampled`
53901178654SLei Zhang //
54001178654SLei Zhang // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
54101178654SLei Zhang //
54201178654SLei Zhang // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
54301178654SLei Zhang //
5445ab6ef75SJakub Kuderski // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
54501178654SLei Zhang //                              arrayed-info `,` sampling-info `,`
54601178654SLei Zhang //                              sampler-use-info `,` format `>`
54701178654SLei Zhang static Type parseImageType(SPIRVDialect const &dialect,
54801178654SLei Zhang                            DialectAsmParser &parser) {
54901178654SLei Zhang   if (parser.parseLess())
55001178654SLei Zhang     return Type();
55101178654SLei Zhang 
55201178654SLei Zhang   auto value =
55301178654SLei Zhang       ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
55401178654SLei Zhang                               ImageSamplingInfo, ImageSamplerUseInfo,
55501178654SLei Zhang                               ImageFormat>{}(dialect, parser);
55601178654SLei Zhang   if (!value)
55701178654SLei Zhang     return Type();
55801178654SLei Zhang 
55901178654SLei Zhang   if (parser.parseGreater())
56001178654SLei Zhang     return Type();
5616d5fc1e3SKazu Hirata   return ImageType::get(*value);
56201178654SLei Zhang }
56301178654SLei Zhang 
5645ab6ef75SJakub Kuderski // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
5652ef24139SWeiwei Li static Type parseSampledImageType(SPIRVDialect const &dialect,
5662ef24139SWeiwei Li                                   DialectAsmParser &parser) {
5672ef24139SWeiwei Li   if (parser.parseLess())
5682ef24139SWeiwei Li     return Type();
5692ef24139SWeiwei Li 
5702ef24139SWeiwei Li   Type parsedType = parseAndVerifySampledImageType(dialect, parser);
5712ef24139SWeiwei Li   if (!parsedType)
5722ef24139SWeiwei Li     return Type();
5732ef24139SWeiwei Li 
5742ef24139SWeiwei Li   if (parser.parseGreater())
5752ef24139SWeiwei Li     return Type();
5762ef24139SWeiwei Li   return SampledImageType::get(parsedType);
5772ef24139SWeiwei Li }
5782ef24139SWeiwei Li 
57901178654SLei Zhang // Parse decorations associated with a member.
58001178654SLei Zhang static ParseResult parseStructMemberDecorations(
58101178654SLei Zhang     SPIRVDialect const &dialect, DialectAsmParser &parser,
58201178654SLei Zhang     ArrayRef<Type> memberTypes,
58301178654SLei Zhang     SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
58401178654SLei Zhang     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
58501178654SLei Zhang 
58601178654SLei Zhang   // Check if the first element is offset.
5876842ec42SRiver Riddle   SMLoc offsetLoc = parser.getCurrentLocation();
58801178654SLei Zhang   StructType::OffsetInfo offset = 0;
58901178654SLei Zhang   OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
5909750648cSKazu Hirata   if (offsetParseResult.has_value()) {
59101178654SLei Zhang     if (failed(*offsetParseResult))
59201178654SLei Zhang       return failure();
59301178654SLei Zhang 
59401178654SLei Zhang     if (offsetInfo.size() != memberTypes.size() - 1) {
59501178654SLei Zhang       return parser.emitError(offsetLoc,
59601178654SLei Zhang                               "offset specification must be given for "
59701178654SLei Zhang                               "all members");
59801178654SLei Zhang     }
59901178654SLei Zhang     offsetInfo.push_back(offset);
60001178654SLei Zhang   }
60101178654SLei Zhang 
60201178654SLei Zhang   // Check for no spirv::Decorations.
60301178654SLei Zhang   if (succeeded(parser.parseOptionalRSquare()))
60401178654SLei Zhang     return success();
60501178654SLei Zhang 
60601178654SLei Zhang   // If there was an offset, make sure to parse the comma.
6079750648cSKazu Hirata   if (offsetParseResult.has_value() && parser.parseComma())
60801178654SLei Zhang     return failure();
60901178654SLei Zhang 
61001178654SLei Zhang   // Check for spirv::Decorations.
611167bbfcbSJakub Tucholski   auto parseDecorations = [&]() {
61201178654SLei Zhang     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
61301178654SLei Zhang     if (!memberDecoration)
61401178654SLei Zhang       return failure();
61501178654SLei Zhang 
61601178654SLei Zhang     // Parse member decoration value if it exists.
61701178654SLei Zhang     if (succeeded(parser.parseOptionalEqual())) {
61801178654SLei Zhang       auto memberDecorationValue =
61901178654SLei Zhang           parseAndVerifyInteger<uint32_t>(dialect, parser);
62001178654SLei Zhang 
62101178654SLei Zhang       if (!memberDecorationValue)
62201178654SLei Zhang         return failure();
62301178654SLei Zhang 
62401178654SLei Zhang       memberDecorationInfo.emplace_back(
62501178654SLei Zhang           static_cast<uint32_t>(memberTypes.size() - 1), 1,
626c27d8152SKazu Hirata           memberDecoration.value(), memberDecorationValue.value());
62701178654SLei Zhang     } else {
62801178654SLei Zhang       memberDecorationInfo.emplace_back(
62901178654SLei Zhang           static_cast<uint32_t>(memberTypes.size() - 1), 0,
630c27d8152SKazu Hirata           memberDecoration.value(), 0);
63101178654SLei Zhang     }
632167bbfcbSJakub Tucholski     return success();
633167bbfcbSJakub Tucholski   };
634167bbfcbSJakub Tucholski   if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
635167bbfcbSJakub Tucholski       failed(parser.parseRSquare()))
636167bbfcbSJakub Tucholski     return failure();
63701178654SLei Zhang 
638167bbfcbSJakub Tucholski   return success();
63901178654SLei Zhang }
64001178654SLei Zhang 
64101178654SLei Zhang // struct-member-decoration ::= integer-literal? spirv-decoration*
64201178654SLei Zhang // struct-type ::=
6435ab6ef75SJakub Kuderski //             `!spirv.struct<` (id `,`)?
64401178654SLei Zhang //                          `(`
64501178654SLei Zhang //                            (spirv-type (`[` struct-member-decoration `]`)?)*
64601178654SLei Zhang //                          `)>`
64701178654SLei Zhang static Type parseStructType(SPIRVDialect const &dialect,
64801178654SLei Zhang                             DialectAsmParser &parser) {
64901178654SLei Zhang   // TODO: This function is quite lengthy. Break it down into smaller chunks.
65001178654SLei Zhang 
65101178654SLei Zhang   if (parser.parseLess())
65201178654SLei Zhang     return Type();
65301178654SLei Zhang 
65401178654SLei Zhang   StringRef identifier;
655b121c266SMarkus Böck   FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
65601178654SLei Zhang 
65701178654SLei Zhang   // Check if this is an identified struct type.
65801178654SLei Zhang   if (succeeded(parser.parseOptionalKeyword(&identifier))) {
65901178654SLei Zhang     // Check if this is a possible recursive reference.
660b121c266SMarkus Böck     auto structType =
661b121c266SMarkus Böck         StructType::getIdentified(dialect.getContext(), identifier);
662b121c266SMarkus Böck     cyclicParse = parser.tryStartCyclicParse(structType);
66301178654SLei Zhang     if (succeeded(parser.parseOptionalGreater())) {
664b121c266SMarkus Böck       if (succeeded(cyclicParse)) {
66501178654SLei Zhang         parser.emitError(
66601178654SLei Zhang             parser.getNameLoc(),
66701178654SLei Zhang             "recursive struct reference not nested in struct definition");
66801178654SLei Zhang 
66901178654SLei Zhang         return Type();
67001178654SLei Zhang       }
67101178654SLei Zhang 
672b121c266SMarkus Böck       return structType;
67301178654SLei Zhang     }
67401178654SLei Zhang 
67501178654SLei Zhang     if (failed(parser.parseComma()))
67601178654SLei Zhang       return Type();
67701178654SLei Zhang 
678b121c266SMarkus Böck     if (failed(cyclicParse)) {
67901178654SLei Zhang       parser.emitError(parser.getNameLoc(),
68001178654SLei Zhang                        "identifier already used for an enclosing struct");
681b121c266SMarkus Böck       return Type();
68201178654SLei Zhang     }
68301178654SLei Zhang   }
68401178654SLei Zhang 
68501178654SLei Zhang   if (failed(parser.parseLParen()))
686b121c266SMarkus Böck     return Type();
68701178654SLei Zhang 
68801178654SLei Zhang   if (succeeded(parser.parseOptionalRParen()) &&
68901178654SLei Zhang       succeeded(parser.parseOptionalGreater())) {
69001178654SLei Zhang     return StructType::getEmpty(dialect.getContext(), identifier);
69101178654SLei Zhang   }
69201178654SLei Zhang 
69301178654SLei Zhang   StructType idStructTy;
69401178654SLei Zhang 
69501178654SLei Zhang   if (!identifier.empty())
69601178654SLei Zhang     idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
69701178654SLei Zhang 
69801178654SLei Zhang   SmallVector<Type, 4> memberTypes;
69901178654SLei Zhang   SmallVector<StructType::OffsetInfo, 4> offsetInfo;
70001178654SLei Zhang   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
70101178654SLei Zhang 
70201178654SLei Zhang   do {
70301178654SLei Zhang     Type memberType;
70401178654SLei Zhang     if (parser.parseType(memberType))
705b121c266SMarkus Böck       return Type();
70601178654SLei Zhang     memberTypes.push_back(memberType);
70701178654SLei Zhang 
70801178654SLei Zhang     if (succeeded(parser.parseOptionalLSquare()))
70901178654SLei Zhang       if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
71001178654SLei Zhang                                        memberDecorationInfo))
711b121c266SMarkus Böck         return Type();
71201178654SLei Zhang   } while (succeeded(parser.parseOptionalComma()));
71301178654SLei Zhang 
71401178654SLei Zhang   if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
71501178654SLei Zhang     parser.emitError(parser.getNameLoc(),
71601178654SLei Zhang                      "offset specification must be given for all members");
717b121c266SMarkus Böck     return Type();
71801178654SLei Zhang   }
71901178654SLei Zhang 
72001178654SLei Zhang   if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721b121c266SMarkus Böck     return Type();
72201178654SLei Zhang 
72301178654SLei Zhang   if (!identifier.empty()) {
72401178654SLei Zhang     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
72501178654SLei Zhang                                      memberDecorationInfo)))
72601178654SLei Zhang       return Type();
72701178654SLei Zhang     return idStructTy;
72801178654SLei Zhang   }
72901178654SLei Zhang 
73001178654SLei Zhang   return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
73101178654SLei Zhang }
73201178654SLei Zhang 
73301178654SLei Zhang // spirv-type ::= array-type
73401178654SLei Zhang //              | element-type
73501178654SLei Zhang //              | image-type
73601178654SLei Zhang //              | pointer-type
73701178654SLei Zhang //              | runtime-array-type
7382ef24139SWeiwei Li //              | sampled-image-type
73901178654SLei Zhang //              | struct-type
74001178654SLei Zhang Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
74101178654SLei Zhang   StringRef keyword;
74201178654SLei Zhang   if (parser.parseKeyword(&keyword))
74301178654SLei Zhang     return Type();
74401178654SLei Zhang 
74501178654SLei Zhang   if (keyword == "array")
74601178654SLei Zhang     return parseArrayType(*this, parser);
7474ba61f5aSJakub Kuderski   if (keyword == "coopmatrix")
74801178654SLei Zhang     return parseCooperativeMatrixType(*this, parser);
74901178654SLei Zhang   if (keyword == "image")
75001178654SLei Zhang     return parseImageType(*this, parser);
75101178654SLei Zhang   if (keyword == "ptr")
75201178654SLei Zhang     return parsePointerType(*this, parser);
75301178654SLei Zhang   if (keyword == "rtarray")
75401178654SLei Zhang     return parseRuntimeArrayType(*this, parser);
7552ef24139SWeiwei Li   if (keyword == "sampled_image")
7562ef24139SWeiwei Li     return parseSampledImageType(*this, parser);
75701178654SLei Zhang   if (keyword == "struct")
75801178654SLei Zhang     return parseStructType(*this, parser);
75901178654SLei Zhang   if (keyword == "matrix")
76001178654SLei Zhang     return parseMatrixType(*this, parser);
76101178654SLei Zhang   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
76201178654SLei Zhang   return Type();
76301178654SLei Zhang }
76401178654SLei Zhang 
76501178654SLei Zhang //===----------------------------------------------------------------------===//
76601178654SLei Zhang // Type Printing
76701178654SLei Zhang //===----------------------------------------------------------------------===//
76801178654SLei Zhang 
76901178654SLei Zhang static void print(ArrayType type, DialectAsmPrinter &os) {
77001178654SLei Zhang   os << "array<" << type.getNumElements() << " x " << type.getElementType();
77101178654SLei Zhang   if (unsigned stride = type.getArrayStride())
77201178654SLei Zhang     os << ", stride=" << stride;
77301178654SLei Zhang   os << ">";
77401178654SLei Zhang }
77501178654SLei Zhang 
77601178654SLei Zhang static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
77701178654SLei Zhang   os << "rtarray<" << type.getElementType();
77801178654SLei Zhang   if (unsigned stride = type.getArrayStride())
77901178654SLei Zhang     os << ", stride=" << stride;
78001178654SLei Zhang   os << ">";
78101178654SLei Zhang }
78201178654SLei Zhang 
78301178654SLei Zhang static void print(PointerType type, DialectAsmPrinter &os) {
78401178654SLei Zhang   os << "ptr<" << type.getPointeeType() << ", "
78501178654SLei Zhang      << stringifyStorageClass(type.getStorageClass()) << ">";
78601178654SLei Zhang }
78701178654SLei Zhang 
78801178654SLei Zhang static void print(ImageType type, DialectAsmPrinter &os) {
78901178654SLei Zhang   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
79001178654SLei Zhang      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
79101178654SLei Zhang      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
79201178654SLei Zhang      << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
79301178654SLei Zhang      << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
79401178654SLei Zhang      << stringifyImageFormat(type.getImageFormat()) << ">";
79501178654SLei Zhang }
79601178654SLei Zhang 
7972ef24139SWeiwei Li static void print(SampledImageType type, DialectAsmPrinter &os) {
7982ef24139SWeiwei Li   os << "sampled_image<" << type.getImageType() << ">";
7992ef24139SWeiwei Li }
8002ef24139SWeiwei Li 
80101178654SLei Zhang static void print(StructType type, DialectAsmPrinter &os) {
802b121c266SMarkus Böck   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
80301178654SLei Zhang 
80401178654SLei Zhang   os << "struct<";
80501178654SLei Zhang 
80601178654SLei Zhang   if (type.isIdentified()) {
80701178654SLei Zhang     os << type.getIdentifier();
80801178654SLei Zhang 
809b121c266SMarkus Böck     cyclicPrint = os.tryStartCyclicPrint(type);
810b121c266SMarkus Böck     if (failed(cyclicPrint)) {
81101178654SLei Zhang       os << ">";
81201178654SLei Zhang       return;
81301178654SLei Zhang     }
81401178654SLei Zhang 
81501178654SLei Zhang     os << ", ";
81601178654SLei Zhang   }
81701178654SLei Zhang 
81801178654SLei Zhang   os << "(";
81901178654SLei Zhang 
82001178654SLei Zhang   auto printMember = [&](unsigned i) {
82101178654SLei Zhang     os << type.getElementType(i);
82201178654SLei Zhang     SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
82301178654SLei Zhang     type.getMemberDecorations(i, decorations);
82401178654SLei Zhang     if (type.hasOffset() || !decorations.empty()) {
82501178654SLei Zhang       os << " [";
82601178654SLei Zhang       if (type.hasOffset()) {
82701178654SLei Zhang         os << type.getMemberOffset(i);
82801178654SLei Zhang         if (!decorations.empty())
82901178654SLei Zhang           os << ", ";
83001178654SLei Zhang       }
83101178654SLei Zhang       auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
83201178654SLei Zhang         os << stringifyDecoration(decoration.decoration);
83301178654SLei Zhang         if (decoration.hasValue) {
83401178654SLei Zhang           os << "=" << decoration.decorationValue;
83501178654SLei Zhang         }
83601178654SLei Zhang       };
83701178654SLei Zhang       llvm::interleaveComma(decorations, os, eachFn);
83801178654SLei Zhang       os << "]";
83901178654SLei Zhang     }
84001178654SLei Zhang   };
84101178654SLei Zhang   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
84201178654SLei Zhang                         printMember);
84301178654SLei Zhang   os << ")>";
84401178654SLei Zhang }
84501178654SLei Zhang 
8464ba61f5aSJakub Kuderski static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
8474ba61f5aSJakub Kuderski   os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
8484ba61f5aSJakub Kuderski      << type.getElementType() << ", " << type.getScope() << ", "
8494ba61f5aSJakub Kuderski      << type.getUse() << ">";
8504ba61f5aSJakub Kuderski }
8514ba61f5aSJakub Kuderski 
85201178654SLei Zhang static void print(MatrixType type, DialectAsmPrinter &os) {
85301178654SLei Zhang   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
85401178654SLei Zhang   os << ">";
85501178654SLei Zhang }
85601178654SLei Zhang 
85701178654SLei Zhang void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
85801178654SLei Zhang   TypeSwitch<Type>(type)
859*b719ab4eSAndrea Faulds       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
860*b719ab4eSAndrea Faulds             ImageType, SampledImageType, StructType, MatrixType>(
861*b719ab4eSAndrea Faulds           [&](auto type) { print(type, os); })
86201178654SLei Zhang       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
86301178654SLei Zhang }
86401178654SLei Zhang 
86501178654SLei Zhang //===----------------------------------------------------------------------===//
86601178654SLei Zhang // Constant
86701178654SLei Zhang //===----------------------------------------------------------------------===//
86801178654SLei Zhang 
86901178654SLei Zhang Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
87001178654SLei Zhang                                              Attribute value, Type type,
87101178654SLei Zhang                                              Location loc) {
8725dce7481SIvan Butygin   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
8735dce7481SIvan Butygin     return builder.create<ub::PoisonOp>(loc, type, poison);
8745dce7481SIvan Butygin 
87501178654SLei Zhang   if (!spirv::ConstantOp::isBuildableWith(type))
87601178654SLei Zhang     return nullptr;
87701178654SLei Zhang 
87801178654SLei Zhang   return builder.create<spirv::ConstantOp>(loc, type, value);
87901178654SLei Zhang }
88001178654SLei Zhang 
88101178654SLei Zhang //===----------------------------------------------------------------------===//
88201178654SLei Zhang // Shader Interface ABI
88301178654SLei Zhang //===----------------------------------------------------------------------===//
88401178654SLei Zhang 
88501178654SLei Zhang LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
88601178654SLei Zhang                                                      NamedAttribute attribute) {
8870c7890c8SRiver Riddle   StringRef symbol = attribute.getName().strref();
8880c7890c8SRiver Riddle   Attribute attr = attribute.getValue();
88901178654SLei Zhang 
89001178654SLei Zhang   if (symbol == spirv::getEntryPointABIAttrName()) {
891c1fa60b4STres Popp     if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
89201178654SLei Zhang       return op->emitError("'")
893a31ff0afSMogball              << symbol << "' attribute must be an entry point ABI attribute";
894a31ff0afSMogball     }
89501178654SLei Zhang   } else if (symbol == spirv::getTargetEnvAttrName()) {
896c1fa60b4STres Popp     if (!llvm::isa<spirv::TargetEnvAttr>(attr))
89701178654SLei Zhang       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
89801178654SLei Zhang   } else {
89901178654SLei Zhang     return op->emitError("found unsupported '")
90001178654SLei Zhang            << symbol << "' attribute on operation";
90101178654SLei Zhang   }
90201178654SLei Zhang 
90301178654SLei Zhang   return success();
90401178654SLei Zhang }
90501178654SLei Zhang 
90601178654SLei Zhang /// Verifies the given SPIR-V `attribute` attached to a value of the given
90701178654SLei Zhang /// `valueType` is valid.
90801178654SLei Zhang static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
90901178654SLei Zhang                                            NamedAttribute attribute) {
9100c7890c8SRiver Riddle   StringRef symbol = attribute.getName().strref();
9110c7890c8SRiver Riddle   Attribute attr = attribute.getValue();
91201178654SLei Zhang 
913747d8fb0SKohei Yamaguchi   if (symbol == spirv::getInterfaceVarABIAttrName()) {
914c1fa60b4STres Popp     auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
91501178654SLei Zhang     if (!varABIAttr)
91601178654SLei Zhang       return emitError(loc, "'")
91701178654SLei Zhang              << symbol << "' must be a spirv::InterfaceVarABIAttr";
91801178654SLei Zhang 
91901178654SLei Zhang     if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
92001178654SLei Zhang       return emitError(loc, "'") << symbol
92101178654SLei Zhang                                  << "' attribute cannot specify storage class "
92201178654SLei Zhang                                     "when attaching to a non-scalar value";
92301178654SLei Zhang     return success();
92401178654SLei Zhang   }
925747d8fb0SKohei Yamaguchi   if (symbol == spirv::DecorationAttr::name) {
926747d8fb0SKohei Yamaguchi     if (!isa<spirv::DecorationAttr>(attr))
927747d8fb0SKohei Yamaguchi       return emitError(loc, "'")
928747d8fb0SKohei Yamaguchi              << symbol << "' must be a spirv::DecorationAttr";
929747d8fb0SKohei Yamaguchi     return success();
930747d8fb0SKohei Yamaguchi   }
931747d8fb0SKohei Yamaguchi 
932747d8fb0SKohei Yamaguchi   return emitError(loc, "found unsupported '")
933747d8fb0SKohei Yamaguchi          << symbol << "' attribute on region argument";
934747d8fb0SKohei Yamaguchi }
93501178654SLei Zhang 
93601178654SLei Zhang LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
93701178654SLei Zhang                                                      unsigned regionIndex,
93801178654SLei Zhang                                                      unsigned argIndex,
93901178654SLei Zhang                                                      NamedAttribute attribute) {
940747d8fb0SKohei Yamaguchi   auto funcOp = dyn_cast<FunctionOpInterface>(op);
941747d8fb0SKohei Yamaguchi   if (!funcOp)
942747d8fb0SKohei Yamaguchi     return success();
943747d8fb0SKohei Yamaguchi   Type argType = funcOp.getArgumentTypes()[argIndex];
944747d8fb0SKohei Yamaguchi 
945747d8fb0SKohei Yamaguchi   return verifyRegionAttribute(op->getLoc(), argType, attribute);
94601178654SLei Zhang }
94701178654SLei Zhang 
94801178654SLei Zhang LogicalResult SPIRVDialect::verifyRegionResultAttribute(
94901178654SLei Zhang     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
95001178654SLei Zhang     NamedAttribute attribute) {
95101178654SLei Zhang   return op->emitError("cannot attach SPIR-V attributes to region result");
95201178654SLei Zhang }
953