101178654SLei Zhang //===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===// 201178654SLei Zhang // 301178654SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM 401178654SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 501178654SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 601178654SLei Zhang // 701178654SLei Zhang //===----------------------------------------------------------------------===// 801178654SLei Zhang // 901178654SLei Zhang // This file defines the SPIR-V dialect in MLIR. 1001178654SLei Zhang // 1101178654SLei Zhang //===----------------------------------------------------------------------===// 1201178654SLei Zhang 1301178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 149415241cSJakub Kuderski 159415241cSJakub Kuderski #include "SPIRVParsingUtils.h" 169415241cSJakub Kuderski 172dace045SSang Ik Lee #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" 1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 1901178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 2001178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 215dce7481SIvan Butygin #include "mlir/Dialect/UB/IR/UBOps.h" 2201178654SLei Zhang #include "mlir/IR/Builders.h" 2301178654SLei Zhang #include "mlir/IR/BuiltinTypes.h" 2401178654SLei Zhang #include "mlir/IR/DialectImplementation.h" 2501178654SLei Zhang #include "mlir/IR/MLIRContext.h" 269eaff423SRiver Riddle #include "mlir/Parser/Parser.h" 2701178654SLei Zhang #include "mlir/Transforms/InliningUtils.h" 2801178654SLei Zhang #include "llvm/ADT/DenseMap.h" 2901178654SLei Zhang #include "llvm/ADT/Sequence.h" 3001178654SLei Zhang #include "llvm/ADT/SetVector.h" 3101178654SLei Zhang #include "llvm/ADT/StringExtras.h" 3201178654SLei Zhang #include "llvm/ADT/StringMap.h" 3301178654SLei Zhang #include "llvm/ADT/TypeSwitch.h" 3401178654SLei Zhang #include "llvm/Support/raw_ostream.h" 3501178654SLei Zhang 3601178654SLei Zhang using namespace mlir; 3701178654SLei Zhang using namespace mlir::spirv; 3801178654SLei Zhang 39485cc55eSStella Laurenzo #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc" 40485cc55eSStella Laurenzo 4101178654SLei Zhang //===----------------------------------------------------------------------===// 4201178654SLei Zhang // InlinerInterface 4301178654SLei Zhang //===----------------------------------------------------------------------===// 4401178654SLei Zhang 455ab6ef75SJakub Kuderski /// Returns true if the given region contains spirv.Return or spirv.ReturnValue 465ab6ef75SJakub Kuderski /// ops. 4701178654SLei Zhang static inline bool containsReturn(Region ®ion) { 4801178654SLei Zhang return llvm::any_of(region, [](Block &block) { 4901178654SLei Zhang Operation *terminator = block.getTerminator(); 5001178654SLei Zhang return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); 5101178654SLei Zhang }); 5201178654SLei Zhang } 5301178654SLei Zhang 5401178654SLei Zhang namespace { 5501178654SLei Zhang /// This class defines the interface for inlining within the SPIR-V dialect. 5601178654SLei Zhang struct SPIRVInlinerInterface : public DialectInlinerInterface { 5701178654SLei Zhang using DialectInlinerInterface::DialectInlinerInterface; 5801178654SLei Zhang 5901178654SLei Zhang /// All call operations within SPIRV can be inlined. 6001178654SLei Zhang bool isLegalToInline(Operation *call, Operation *callable, 6101178654SLei Zhang bool wouldBeCloned) const final { 6201178654SLei Zhang return true; 6301178654SLei Zhang } 6401178654SLei Zhang 6501178654SLei Zhang /// Returns true if the given region 'src' can be inlined into the region 6601178654SLei Zhang /// 'dest' that is attached to an operation registered to the current dialect. 6701178654SLei Zhang bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 684d67b278SJeff Niu IRMapping &) const final { 695ab6ef75SJakub Kuderski // Return true here when inlining into spirv.func, spirv.mlir.selection, and 705ab6ef75SJakub Kuderski // spirv.mlir.loop operations. 7101178654SLei Zhang auto *op = dest->getParentOp(); 7201178654SLei Zhang return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); 7301178654SLei Zhang } 7401178654SLei Zhang 7501178654SLei Zhang /// Returns true if the given operation 'op', that is registered to this 7601178654SLei Zhang /// dialect, can be inlined into the region 'dest' that is attached to an 7701178654SLei Zhang /// operation registered to the current dialect. 7801178654SLei Zhang bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 794d67b278SJeff Niu IRMapping &) const final { 8001178654SLei Zhang // TODO: Enable inlining structured control flows with return. 8101178654SLei Zhang if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && 8201178654SLei Zhang containsReturn(op->getRegion(0))) 8301178654SLei Zhang return false; 8401178654SLei Zhang // TODO: we need to filter OpKill here to avoid inlining it to 8501178654SLei Zhang // a loop continue construct: 8601178654SLei Zhang // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 8701178654SLei Zhang // However OpKill is fragment shader specific and we don't support it yet. 8801178654SLei Zhang return true; 8901178654SLei Zhang } 9001178654SLei Zhang 9101178654SLei Zhang /// Handle the given inlined terminator by replacing it with a new operation 9201178654SLei Zhang /// as necessary. 9301178654SLei Zhang void handleTerminator(Operation *op, Block *newDest) const final { 9401178654SLei Zhang if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { 9501178654SLei Zhang OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); 9601178654SLei Zhang op->erase(); 9701178654SLei Zhang } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { 989a5fb74fSArtem Tyurin OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, 999a5fb74fSArtem Tyurin retValOp->getOperands()); 1009a5fb74fSArtem Tyurin op->erase(); 10101178654SLei Zhang } 10201178654SLei Zhang } 10301178654SLei Zhang 10401178654SLei Zhang /// Handle the given inlined terminator by replacing it with a new operation 10501178654SLei Zhang /// as necessary. 10626a0b277SMehdi Amini void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 1075ab6ef75SJakub Kuderski // Only spirv.ReturnValue needs to be handled here. 10801178654SLei Zhang auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); 10901178654SLei Zhang if (!retValOp) 11001178654SLei Zhang return; 11101178654SLei Zhang 11201178654SLei Zhang // Replace the values directly with the return operands. 11301178654SLei Zhang assert(valuesToRepl.size() == 1 && 1145ab6ef75SJakub Kuderski "spirv.ReturnValue expected to only handle one result"); 11590a1632dSJakub Kuderski valuesToRepl.front().replaceAllUsesWith(retValOp.getValue()); 11601178654SLei Zhang } 11701178654SLei Zhang }; 11801178654SLei Zhang } // namespace 11901178654SLei Zhang 12001178654SLei Zhang //===----------------------------------------------------------------------===// 12101178654SLei Zhang // SPIR-V Dialect 12201178654SLei Zhang //===----------------------------------------------------------------------===// 12301178654SLei Zhang 12401178654SLei Zhang void SPIRVDialect::initialize() { 12531bb8efdSRiver Riddle registerAttributes(); 12631bb8efdSRiver Riddle registerTypes(); 12701178654SLei Zhang 12801178654SLei Zhang // Add SPIR-V ops. 12901178654SLei Zhang addOperations< 13001178654SLei Zhang #define GET_OP_LIST 13101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" 13201178654SLei Zhang >(); 13301178654SLei Zhang 13401178654SLei Zhang addInterfaces<SPIRVInlinerInterface>(); 13501178654SLei Zhang 13601178654SLei Zhang // Allow unknown operations because SPIR-V is extensible. 13701178654SLei Zhang allowUnknownOperations(); 13835d55f28SJustin Fargnoli declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>(); 13901178654SLei Zhang } 14001178654SLei Zhang 14101178654SLei Zhang std::string SPIRVDialect::getAttributeName(Decoration decoration) { 14201178654SLei Zhang return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); 14301178654SLei Zhang } 14401178654SLei Zhang 14501178654SLei Zhang //===----------------------------------------------------------------------===// 14601178654SLei Zhang // Type Parsing 14701178654SLei Zhang //===----------------------------------------------------------------------===// 14801178654SLei Zhang 14901178654SLei Zhang // Forward declarations. 15001178654SLei Zhang template <typename ValTy> 15122426110SRamkumar Ramachandra static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 15201178654SLei Zhang DialectAsmParser &parser); 15301178654SLei Zhang template <> 15422426110SRamkumar Ramachandra std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 15501178654SLei Zhang DialectAsmParser &parser); 15601178654SLei Zhang 15701178654SLei Zhang template <> 15822426110SRamkumar Ramachandra std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 15901178654SLei Zhang DialectAsmParser &parser); 16001178654SLei Zhang 16101178654SLei Zhang static Type parseAndVerifyType(SPIRVDialect const &dialect, 16201178654SLei Zhang DialectAsmParser &parser) { 16301178654SLei Zhang Type type; 1646842ec42SRiver Riddle SMLoc typeLoc = parser.getCurrentLocation(); 16501178654SLei Zhang if (parser.parseType(type)) 16601178654SLei Zhang return Type(); 16701178654SLei Zhang 16801178654SLei Zhang // Allow SPIR-V dialect types 16901178654SLei Zhang if (&type.getDialect() == &dialect) 17001178654SLei Zhang return type; 17101178654SLei Zhang 17201178654SLei Zhang // Check other allowed types 173c1fa60b4STres Popp if (auto t = llvm::dyn_cast<FloatType>(type)) { 17401178654SLei Zhang if (type.isBF16()) { 17501178654SLei Zhang parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); 17601178654SLei Zhang return Type(); 17701178654SLei Zhang } 178c1fa60b4STres Popp } else if (auto t = llvm::dyn_cast<IntegerType>(type)) { 17901178654SLei Zhang if (!ScalarType::isValid(t)) { 18001178654SLei Zhang parser.emitError(typeLoc, 18101178654SLei Zhang "only 1/8/16/32/64-bit integer type allowed but found ") 18201178654SLei Zhang << type; 18301178654SLei Zhang return Type(); 18401178654SLei Zhang } 185c1fa60b4STres Popp } else if (auto t = llvm::dyn_cast<VectorType>(type)) { 18601178654SLei Zhang if (t.getRank() != 1) { 18701178654SLei Zhang parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 18801178654SLei Zhang return Type(); 18901178654SLei Zhang } 19001178654SLei Zhang if (t.getNumElements() > 4) { 19101178654SLei Zhang parser.emitError( 19201178654SLei Zhang typeLoc, "vector length has to be less than or equal to 4 but found ") 19301178654SLei Zhang << t.getNumElements(); 19401178654SLei Zhang return Type(); 19501178654SLei Zhang } 19601178654SLei Zhang } else { 19701178654SLei Zhang parser.emitError(typeLoc, "cannot use ") 19801178654SLei Zhang << type << " to compose SPIR-V types"; 19901178654SLei Zhang return Type(); 20001178654SLei Zhang } 20101178654SLei Zhang 20201178654SLei Zhang return type; 20301178654SLei Zhang } 20401178654SLei Zhang 20501178654SLei Zhang static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, 20601178654SLei Zhang DialectAsmParser &parser) { 20701178654SLei Zhang Type type; 2086842ec42SRiver Riddle SMLoc typeLoc = parser.getCurrentLocation(); 20901178654SLei Zhang if (parser.parseType(type)) 21001178654SLei Zhang return Type(); 21101178654SLei Zhang 212c1fa60b4STres Popp if (auto t = llvm::dyn_cast<VectorType>(type)) { 21301178654SLei Zhang if (t.getRank() != 1) { 21401178654SLei Zhang parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 21501178654SLei Zhang return Type(); 21601178654SLei Zhang } 21701178654SLei Zhang if (t.getNumElements() > 4 || t.getNumElements() < 2) { 21801178654SLei Zhang parser.emitError(typeLoc, 21901178654SLei Zhang "matrix columns size has to be less than or equal " 22001178654SLei Zhang "to 4 and greater than or equal 2, but found ") 22101178654SLei Zhang << t.getNumElements(); 22201178654SLei Zhang return Type(); 22301178654SLei Zhang } 22401178654SLei Zhang 225c1fa60b4STres Popp if (!llvm::isa<FloatType>(t.getElementType())) { 22601178654SLei Zhang parser.emitError(typeLoc, "matrix columns' elements must be of " 22701178654SLei Zhang "Float type, got ") 22801178654SLei Zhang << t.getElementType(); 22901178654SLei Zhang return Type(); 23001178654SLei Zhang } 23101178654SLei Zhang } else { 23201178654SLei Zhang parser.emitError(typeLoc, "matrix must be composed using vector " 23301178654SLei Zhang "type, got ") 23401178654SLei Zhang << type; 23501178654SLei Zhang return Type(); 23601178654SLei Zhang } 23701178654SLei Zhang 23801178654SLei Zhang return type; 23901178654SLei Zhang } 24001178654SLei Zhang 2412ef24139SWeiwei Li static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, 2422ef24139SWeiwei Li DialectAsmParser &parser) { 2432ef24139SWeiwei Li Type type; 2446842ec42SRiver Riddle SMLoc typeLoc = parser.getCurrentLocation(); 2452ef24139SWeiwei Li if (parser.parseType(type)) 2462ef24139SWeiwei Li return Type(); 2472ef24139SWeiwei Li 248c1fa60b4STres Popp if (!llvm::isa<ImageType>(type)) { 2492ef24139SWeiwei Li parser.emitError(typeLoc, 2502ef24139SWeiwei Li "sampled image must be composed using image type, got ") 2512ef24139SWeiwei Li << type; 2522ef24139SWeiwei Li return Type(); 2532ef24139SWeiwei Li } 2542ef24139SWeiwei Li 2552ef24139SWeiwei Li return type; 2562ef24139SWeiwei Li } 2572ef24139SWeiwei Li 25801178654SLei Zhang /// Parses an optional `, stride = N` assembly segment. If no parsing failure 25901178654SLei Zhang /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if 26001178654SLei Zhang /// missing. 26101178654SLei Zhang static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, 26201178654SLei Zhang DialectAsmParser &parser, 26301178654SLei Zhang unsigned &stride) { 26401178654SLei Zhang if (failed(parser.parseOptionalComma())) { 26501178654SLei Zhang stride = 0; 26601178654SLei Zhang return success(); 26701178654SLei Zhang } 26801178654SLei Zhang 26901178654SLei Zhang if (parser.parseKeyword("stride") || parser.parseEqual()) 27001178654SLei Zhang return failure(); 27101178654SLei Zhang 2726842ec42SRiver Riddle SMLoc strideLoc = parser.getCurrentLocation(); 27322426110SRamkumar Ramachandra std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); 27401178654SLei Zhang if (!optStride) 27501178654SLei Zhang return failure(); 27601178654SLei Zhang 2776d5fc1e3SKazu Hirata if (!(stride = *optStride)) { 27801178654SLei Zhang parser.emitError(strideLoc, "ArrayStride must be greater than zero"); 27901178654SLei Zhang return failure(); 28001178654SLei Zhang } 28101178654SLei Zhang return success(); 28201178654SLei Zhang } 28301178654SLei Zhang 28401178654SLei Zhang // element-type ::= integer-type 28501178654SLei Zhang // | floating-point-type 28601178654SLei Zhang // | vector-type 28701178654SLei Zhang // | spirv-type 28801178654SLei Zhang // 2895ab6ef75SJakub Kuderski // array-type ::= `!spirv.array` `<` integer-literal `x` element-type 29001178654SLei Zhang // (`,` `stride` `=` integer-literal)? `>` 29101178654SLei Zhang static Type parseArrayType(SPIRVDialect const &dialect, 29201178654SLei Zhang DialectAsmParser &parser) { 29301178654SLei Zhang if (parser.parseLess()) 29401178654SLei Zhang return Type(); 29501178654SLei Zhang 29601178654SLei Zhang SmallVector<int64_t, 1> countDims; 2976842ec42SRiver Riddle SMLoc countLoc = parser.getCurrentLocation(); 29801178654SLei Zhang if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 29901178654SLei Zhang return Type(); 30001178654SLei Zhang if (countDims.size() != 1) { 30101178654SLei Zhang parser.emitError(countLoc, 30201178654SLei Zhang "expected single integer for array element count"); 30301178654SLei Zhang return Type(); 30401178654SLei Zhang } 30501178654SLei Zhang 30601178654SLei Zhang // According to the SPIR-V spec: 30701178654SLei Zhang // "Length is the number of elements in the array. It must be at least 1." 30801178654SLei Zhang int64_t count = countDims[0]; 30901178654SLei Zhang if (count == 0) { 31001178654SLei Zhang parser.emitError(countLoc, "expected array length greater than 0"); 31101178654SLei Zhang return Type(); 31201178654SLei Zhang } 31301178654SLei Zhang 31401178654SLei Zhang Type elementType = parseAndVerifyType(dialect, parser); 31501178654SLei Zhang if (!elementType) 31601178654SLei Zhang return Type(); 31701178654SLei Zhang 31801178654SLei Zhang unsigned stride = 0; 31901178654SLei Zhang if (failed(parseOptionalArrayStride(dialect, parser, stride))) 32001178654SLei Zhang return Type(); 32101178654SLei Zhang 32201178654SLei Zhang if (parser.parseGreater()) 32301178654SLei Zhang return Type(); 32401178654SLei Zhang return ArrayType::get(elementType, count, stride); 32501178654SLei Zhang } 32601178654SLei Zhang 3271d515978SJakub Kuderski // cooperative-matrix-type ::= 3284ba61f5aSJakub Kuderski // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,` 3294ba61f5aSJakub Kuderski // scope `,` use `>` 33001178654SLei Zhang static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, 33101178654SLei Zhang DialectAsmParser &parser) { 33201178654SLei Zhang if (parser.parseLess()) 3334ba61f5aSJakub Kuderski return {}; 3344ba61f5aSJakub Kuderski 3354ba61f5aSJakub Kuderski SmallVector<int64_t, 2> dims; 3364ba61f5aSJakub Kuderski SMLoc countLoc = parser.getCurrentLocation(); 3374ba61f5aSJakub Kuderski if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) 3384ba61f5aSJakub Kuderski return {}; 3394ba61f5aSJakub Kuderski 3404ba61f5aSJakub Kuderski if (dims.size() != 2) { 3414ba61f5aSJakub Kuderski parser.emitError(countLoc, "expected row and column count"); 3424ba61f5aSJakub Kuderski return {}; 3434ba61f5aSJakub Kuderski } 3444ba61f5aSJakub Kuderski 3454ba61f5aSJakub Kuderski auto elementTy = parseAndVerifyType(dialect, parser); 3464ba61f5aSJakub Kuderski if (!elementTy) 3474ba61f5aSJakub Kuderski return {}; 3484ba61f5aSJakub Kuderski 3494ba61f5aSJakub Kuderski Scope scope; 3509415241cSJakub Kuderski if (parser.parseComma() || 3519415241cSJakub Kuderski spirv::parseEnumKeywordAttr(scope, parser, "scope <id>")) 3524ba61f5aSJakub Kuderski return {}; 3534ba61f5aSJakub Kuderski 3544ba61f5aSJakub Kuderski CooperativeMatrixUseKHR use; 3559415241cSJakub Kuderski if (parser.parseComma() || 3569415241cSJakub Kuderski spirv::parseEnumKeywordAttr(use, parser, "use <id>")) 3574ba61f5aSJakub Kuderski return {}; 3584ba61f5aSJakub Kuderski 3594ba61f5aSJakub Kuderski if (parser.parseGreater()) 3604ba61f5aSJakub Kuderski return {}; 3614ba61f5aSJakub Kuderski 3624ba61f5aSJakub Kuderski return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); 3634ba61f5aSJakub Kuderski } 3644ba61f5aSJakub Kuderski 36501178654SLei Zhang // TODO: Reorder methods to be utilities first and parse*Type 36601178654SLei Zhang // methods in alphabetical order 36701178654SLei Zhang // 36801178654SLei Zhang // storage-class ::= `UniformConstant` 36901178654SLei Zhang // | `Uniform` 37001178654SLei Zhang // | `Workgroup` 37101178654SLei Zhang // | <and other storage classes...> 37201178654SLei Zhang // 3735ab6ef75SJakub Kuderski // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>` 37401178654SLei Zhang static Type parsePointerType(SPIRVDialect const &dialect, 37501178654SLei Zhang DialectAsmParser &parser) { 37601178654SLei Zhang if (parser.parseLess()) 37701178654SLei Zhang return Type(); 37801178654SLei Zhang 37901178654SLei Zhang auto pointeeType = parseAndVerifyType(dialect, parser); 38001178654SLei Zhang if (!pointeeType) 38101178654SLei Zhang return Type(); 38201178654SLei Zhang 38301178654SLei Zhang StringRef storageClassSpec; 3846842ec42SRiver Riddle SMLoc storageClassLoc = parser.getCurrentLocation(); 38501178654SLei Zhang if (parser.parseComma() || parser.parseKeyword(&storageClassSpec)) 38601178654SLei Zhang return Type(); 38701178654SLei Zhang 38801178654SLei Zhang auto storageClass = symbolizeStorageClass(storageClassSpec); 38901178654SLei Zhang if (!storageClass) { 39001178654SLei Zhang parser.emitError(storageClassLoc, "unknown storage class: ") 39101178654SLei Zhang << storageClassSpec; 39201178654SLei Zhang return Type(); 39301178654SLei Zhang } 39401178654SLei Zhang if (parser.parseGreater()) 39501178654SLei Zhang return Type(); 39601178654SLei Zhang return PointerType::get(pointeeType, *storageClass); 39701178654SLei Zhang } 39801178654SLei Zhang 3995ab6ef75SJakub Kuderski // runtime-array-type ::= `!spirv.rtarray` `<` element-type 40001178654SLei Zhang // (`,` `stride` `=` integer-literal)? `>` 40101178654SLei Zhang static Type parseRuntimeArrayType(SPIRVDialect const &dialect, 40201178654SLei Zhang DialectAsmParser &parser) { 40301178654SLei Zhang if (parser.parseLess()) 40401178654SLei Zhang return Type(); 40501178654SLei Zhang 40601178654SLei Zhang Type elementType = parseAndVerifyType(dialect, parser); 40701178654SLei Zhang if (!elementType) 40801178654SLei Zhang return Type(); 40901178654SLei Zhang 41001178654SLei Zhang unsigned stride = 0; 41101178654SLei Zhang if (failed(parseOptionalArrayStride(dialect, parser, stride))) 41201178654SLei Zhang return Type(); 41301178654SLei Zhang 41401178654SLei Zhang if (parser.parseGreater()) 41501178654SLei Zhang return Type(); 41601178654SLei Zhang return RuntimeArrayType::get(elementType, stride); 41701178654SLei Zhang } 41801178654SLei Zhang 4195ab6ef75SJakub Kuderski // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>` 42001178654SLei Zhang static Type parseMatrixType(SPIRVDialect const &dialect, 42101178654SLei Zhang DialectAsmParser &parser) { 42201178654SLei Zhang if (parser.parseLess()) 42301178654SLei Zhang return Type(); 42401178654SLei Zhang 42501178654SLei Zhang SmallVector<int64_t, 1> countDims; 4266842ec42SRiver Riddle SMLoc countLoc = parser.getCurrentLocation(); 42701178654SLei Zhang if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 42801178654SLei Zhang return Type(); 42901178654SLei Zhang if (countDims.size() != 1) { 43001178654SLei Zhang parser.emitError(countLoc, "expected single unsigned " 43101178654SLei Zhang "integer for number of columns"); 43201178654SLei Zhang return Type(); 43301178654SLei Zhang } 43401178654SLei Zhang 43501178654SLei Zhang int64_t columnCount = countDims[0]; 43601178654SLei Zhang // According to the specification, Matrices can have 2, 3, or 4 columns 43701178654SLei Zhang if (columnCount < 2 || columnCount > 4) { 43801178654SLei Zhang parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 " 43901178654SLei Zhang "columns"); 44001178654SLei Zhang return Type(); 44101178654SLei Zhang } 44201178654SLei Zhang 44301178654SLei Zhang Type columnType = parseAndVerifyMatrixType(dialect, parser); 44401178654SLei Zhang if (!columnType) 44501178654SLei Zhang return Type(); 44601178654SLei Zhang 44701178654SLei Zhang if (parser.parseGreater()) 44801178654SLei Zhang return Type(); 44901178654SLei Zhang 45001178654SLei Zhang return MatrixType::get(columnType, columnCount); 45101178654SLei Zhang } 45201178654SLei Zhang 45301178654SLei Zhang // Specialize this function to parse each of the parameters that define an 45401178654SLei Zhang // ImageType. By default it assumes this is an enum type. 45501178654SLei Zhang template <typename ValTy> 45622426110SRamkumar Ramachandra static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 45701178654SLei Zhang DialectAsmParser &parser) { 45801178654SLei Zhang StringRef enumSpec; 4596842ec42SRiver Riddle SMLoc enumLoc = parser.getCurrentLocation(); 46001178654SLei Zhang if (parser.parseKeyword(&enumSpec)) { 4611a36588eSKazu Hirata return std::nullopt; 46201178654SLei Zhang } 46301178654SLei Zhang 46401178654SLei Zhang auto val = spirv::symbolizeEnum<ValTy>(enumSpec); 46501178654SLei Zhang if (!val) 46601178654SLei Zhang parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; 46701178654SLei Zhang return val; 46801178654SLei Zhang } 46901178654SLei Zhang 47001178654SLei Zhang template <> 47122426110SRamkumar Ramachandra std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 47201178654SLei Zhang DialectAsmParser &parser) { 47301178654SLei Zhang // TODO: Further verify that the element type can be sampled 47401178654SLei Zhang auto ty = parseAndVerifyType(dialect, parser); 47501178654SLei Zhang if (!ty) 4761a36588eSKazu Hirata return std::nullopt; 47701178654SLei Zhang return ty; 47801178654SLei Zhang } 47901178654SLei Zhang 48001178654SLei Zhang template <typename IntTy> 48122426110SRamkumar Ramachandra static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, 48201178654SLei Zhang DialectAsmParser &parser) { 48301178654SLei Zhang IntTy offsetVal = std::numeric_limits<IntTy>::max(); 48401178654SLei Zhang if (parser.parseInteger(offsetVal)) 4851a36588eSKazu Hirata return std::nullopt; 48601178654SLei Zhang return offsetVal; 48701178654SLei Zhang } 48801178654SLei Zhang 48901178654SLei Zhang template <> 49022426110SRamkumar Ramachandra std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 49101178654SLei Zhang DialectAsmParser &parser) { 49201178654SLei Zhang return parseAndVerifyInteger<unsigned>(dialect, parser); 49301178654SLei Zhang } 49401178654SLei Zhang 49501178654SLei Zhang namespace { 49601178654SLei Zhang // Functor object to parse a comma separated list of specs. The function 49701178654SLei Zhang // parseAndVerify does the actual parsing and verification of individual 49801178654SLei Zhang // elements. This is a functor since parsing the last element of the list 49901178654SLei Zhang // (termination condition) needs partial specialization. 500b7f93c28SJeff Niu template <typename ParseType, typename... Args> 501b7f93c28SJeff Niu struct ParseCommaSeparatedList { 50222426110SRamkumar Ramachandra std::optional<std::tuple<ParseType, Args...>> 50301178654SLei Zhang operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { 50401178654SLei Zhang auto parseVal = parseAndVerify<ParseType>(dialect, parser); 50501178654SLei Zhang if (!parseVal) 5061a36588eSKazu Hirata return std::nullopt; 50701178654SLei Zhang 50801178654SLei Zhang auto numArgs = std::tuple_size<std::tuple<Args...>>::value; 50901178654SLei Zhang if (numArgs != 0 && failed(parser.parseComma())) 5101a36588eSKazu Hirata return std::nullopt; 51101178654SLei Zhang auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); 51201178654SLei Zhang if (!remainingValues) 5131a36588eSKazu Hirata return std::nullopt; 514c27d8152SKazu Hirata return std::tuple_cat(std::tuple<ParseType>(parseVal.value()), 515c27d8152SKazu Hirata remainingValues.value()); 51601178654SLei Zhang } 51701178654SLei Zhang }; 51801178654SLei Zhang 51901178654SLei Zhang // Partial specialization of the function to parse a comma separated list of 52001178654SLei Zhang // specs to parse the last element of the list. 521b7f93c28SJeff Niu template <typename ParseType> 522b7f93c28SJeff Niu struct ParseCommaSeparatedList<ParseType> { 52322426110SRamkumar Ramachandra std::optional<std::tuple<ParseType>> 52422426110SRamkumar Ramachandra operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { 52501178654SLei Zhang if (auto value = parseAndVerify<ParseType>(dialect, parser)) 5266d5fc1e3SKazu Hirata return std::tuple<ParseType>(*value); 5271a36588eSKazu Hirata return std::nullopt; 52801178654SLei Zhang } 52901178654SLei Zhang }; 53001178654SLei Zhang } // namespace 53101178654SLei Zhang 53201178654SLei Zhang // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> 53301178654SLei Zhang // 53401178654SLei Zhang // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` 53501178654SLei Zhang // 53601178654SLei Zhang // arrayed-info ::= `NonArrayed` | `Arrayed` 53701178654SLei Zhang // 53801178654SLei Zhang // sampling-info ::= `SingleSampled` | `MultiSampled` 53901178654SLei Zhang // 54001178654SLei Zhang // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` 54101178654SLei Zhang // 54201178654SLei Zhang // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> 54301178654SLei Zhang // 5445ab6ef75SJakub Kuderski // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,` 54501178654SLei Zhang // arrayed-info `,` sampling-info `,` 54601178654SLei Zhang // sampler-use-info `,` format `>` 54701178654SLei Zhang static Type parseImageType(SPIRVDialect const &dialect, 54801178654SLei Zhang DialectAsmParser &parser) { 54901178654SLei Zhang if (parser.parseLess()) 55001178654SLei Zhang return Type(); 55101178654SLei Zhang 55201178654SLei Zhang auto value = 55301178654SLei Zhang ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 55401178654SLei Zhang ImageSamplingInfo, ImageSamplerUseInfo, 55501178654SLei Zhang ImageFormat>{}(dialect, parser); 55601178654SLei Zhang if (!value) 55701178654SLei Zhang return Type(); 55801178654SLei Zhang 55901178654SLei Zhang if (parser.parseGreater()) 56001178654SLei Zhang return Type(); 5616d5fc1e3SKazu Hirata return ImageType::get(*value); 56201178654SLei Zhang } 56301178654SLei Zhang 5645ab6ef75SJakub Kuderski // sampledImage-type :: = `!spirv.sampledImage<` image-type `>` 5652ef24139SWeiwei Li static Type parseSampledImageType(SPIRVDialect const &dialect, 5662ef24139SWeiwei Li DialectAsmParser &parser) { 5672ef24139SWeiwei Li if (parser.parseLess()) 5682ef24139SWeiwei Li return Type(); 5692ef24139SWeiwei Li 5702ef24139SWeiwei Li Type parsedType = parseAndVerifySampledImageType(dialect, parser); 5712ef24139SWeiwei Li if (!parsedType) 5722ef24139SWeiwei Li return Type(); 5732ef24139SWeiwei Li 5742ef24139SWeiwei Li if (parser.parseGreater()) 5752ef24139SWeiwei Li return Type(); 5762ef24139SWeiwei Li return SampledImageType::get(parsedType); 5772ef24139SWeiwei Li } 5782ef24139SWeiwei Li 57901178654SLei Zhang // Parse decorations associated with a member. 58001178654SLei Zhang static ParseResult parseStructMemberDecorations( 58101178654SLei Zhang SPIRVDialect const &dialect, DialectAsmParser &parser, 58201178654SLei Zhang ArrayRef<Type> memberTypes, 58301178654SLei Zhang SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, 58401178654SLei Zhang SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { 58501178654SLei Zhang 58601178654SLei Zhang // Check if the first element is offset. 5876842ec42SRiver Riddle SMLoc offsetLoc = parser.getCurrentLocation(); 58801178654SLei Zhang StructType::OffsetInfo offset = 0; 58901178654SLei Zhang OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); 5909750648cSKazu Hirata if (offsetParseResult.has_value()) { 59101178654SLei Zhang if (failed(*offsetParseResult)) 59201178654SLei Zhang return failure(); 59301178654SLei Zhang 59401178654SLei Zhang if (offsetInfo.size() != memberTypes.size() - 1) { 59501178654SLei Zhang return parser.emitError(offsetLoc, 59601178654SLei Zhang "offset specification must be given for " 59701178654SLei Zhang "all members"); 59801178654SLei Zhang } 59901178654SLei Zhang offsetInfo.push_back(offset); 60001178654SLei Zhang } 60101178654SLei Zhang 60201178654SLei Zhang // Check for no spirv::Decorations. 60301178654SLei Zhang if (succeeded(parser.parseOptionalRSquare())) 60401178654SLei Zhang return success(); 60501178654SLei Zhang 60601178654SLei Zhang // If there was an offset, make sure to parse the comma. 6079750648cSKazu Hirata if (offsetParseResult.has_value() && parser.parseComma()) 60801178654SLei Zhang return failure(); 60901178654SLei Zhang 61001178654SLei Zhang // Check for spirv::Decorations. 611167bbfcbSJakub Tucholski auto parseDecorations = [&]() { 61201178654SLei Zhang auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); 61301178654SLei Zhang if (!memberDecoration) 61401178654SLei Zhang return failure(); 61501178654SLei Zhang 61601178654SLei Zhang // Parse member decoration value if it exists. 61701178654SLei Zhang if (succeeded(parser.parseOptionalEqual())) { 61801178654SLei Zhang auto memberDecorationValue = 61901178654SLei Zhang parseAndVerifyInteger<uint32_t>(dialect, parser); 62001178654SLei Zhang 62101178654SLei Zhang if (!memberDecorationValue) 62201178654SLei Zhang return failure(); 62301178654SLei Zhang 62401178654SLei Zhang memberDecorationInfo.emplace_back( 62501178654SLei Zhang static_cast<uint32_t>(memberTypes.size() - 1), 1, 626c27d8152SKazu Hirata memberDecoration.value(), memberDecorationValue.value()); 62701178654SLei Zhang } else { 62801178654SLei Zhang memberDecorationInfo.emplace_back( 62901178654SLei Zhang static_cast<uint32_t>(memberTypes.size() - 1), 0, 630c27d8152SKazu Hirata memberDecoration.value(), 0); 63101178654SLei Zhang } 632167bbfcbSJakub Tucholski return success(); 633167bbfcbSJakub Tucholski }; 634167bbfcbSJakub Tucholski if (failed(parser.parseCommaSeparatedList(parseDecorations)) || 635167bbfcbSJakub Tucholski failed(parser.parseRSquare())) 636167bbfcbSJakub Tucholski return failure(); 63701178654SLei Zhang 638167bbfcbSJakub Tucholski return success(); 63901178654SLei Zhang } 64001178654SLei Zhang 64101178654SLei Zhang // struct-member-decoration ::= integer-literal? spirv-decoration* 64201178654SLei Zhang // struct-type ::= 6435ab6ef75SJakub Kuderski // `!spirv.struct<` (id `,`)? 64401178654SLei Zhang // `(` 64501178654SLei Zhang // (spirv-type (`[` struct-member-decoration `]`)?)* 64601178654SLei Zhang // `)>` 64701178654SLei Zhang static Type parseStructType(SPIRVDialect const &dialect, 64801178654SLei Zhang DialectAsmParser &parser) { 64901178654SLei Zhang // TODO: This function is quite lengthy. Break it down into smaller chunks. 65001178654SLei Zhang 65101178654SLei Zhang if (parser.parseLess()) 65201178654SLei Zhang return Type(); 65301178654SLei Zhang 65401178654SLei Zhang StringRef identifier; 655b121c266SMarkus Böck FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse; 65601178654SLei Zhang 65701178654SLei Zhang // Check if this is an identified struct type. 65801178654SLei Zhang if (succeeded(parser.parseOptionalKeyword(&identifier))) { 65901178654SLei Zhang // Check if this is a possible recursive reference. 660b121c266SMarkus Böck auto structType = 661b121c266SMarkus Böck StructType::getIdentified(dialect.getContext(), identifier); 662b121c266SMarkus Böck cyclicParse = parser.tryStartCyclicParse(structType); 66301178654SLei Zhang if (succeeded(parser.parseOptionalGreater())) { 664b121c266SMarkus Böck if (succeeded(cyclicParse)) { 66501178654SLei Zhang parser.emitError( 66601178654SLei Zhang parser.getNameLoc(), 66701178654SLei Zhang "recursive struct reference not nested in struct definition"); 66801178654SLei Zhang 66901178654SLei Zhang return Type(); 67001178654SLei Zhang } 67101178654SLei Zhang 672b121c266SMarkus Böck return structType; 67301178654SLei Zhang } 67401178654SLei Zhang 67501178654SLei Zhang if (failed(parser.parseComma())) 67601178654SLei Zhang return Type(); 67701178654SLei Zhang 678b121c266SMarkus Böck if (failed(cyclicParse)) { 67901178654SLei Zhang parser.emitError(parser.getNameLoc(), 68001178654SLei Zhang "identifier already used for an enclosing struct"); 681b121c266SMarkus Böck return Type(); 68201178654SLei Zhang } 68301178654SLei Zhang } 68401178654SLei Zhang 68501178654SLei Zhang if (failed(parser.parseLParen())) 686b121c266SMarkus Böck return Type(); 68701178654SLei Zhang 68801178654SLei Zhang if (succeeded(parser.parseOptionalRParen()) && 68901178654SLei Zhang succeeded(parser.parseOptionalGreater())) { 69001178654SLei Zhang return StructType::getEmpty(dialect.getContext(), identifier); 69101178654SLei Zhang } 69201178654SLei Zhang 69301178654SLei Zhang StructType idStructTy; 69401178654SLei Zhang 69501178654SLei Zhang if (!identifier.empty()) 69601178654SLei Zhang idStructTy = StructType::getIdentified(dialect.getContext(), identifier); 69701178654SLei Zhang 69801178654SLei Zhang SmallVector<Type, 4> memberTypes; 69901178654SLei Zhang SmallVector<StructType::OffsetInfo, 4> offsetInfo; 70001178654SLei Zhang SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; 70101178654SLei Zhang 70201178654SLei Zhang do { 70301178654SLei Zhang Type memberType; 70401178654SLei Zhang if (parser.parseType(memberType)) 705b121c266SMarkus Böck return Type(); 70601178654SLei Zhang memberTypes.push_back(memberType); 70701178654SLei Zhang 70801178654SLei Zhang if (succeeded(parser.parseOptionalLSquare())) 70901178654SLei Zhang if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, 71001178654SLei Zhang memberDecorationInfo)) 711b121c266SMarkus Böck return Type(); 71201178654SLei Zhang } while (succeeded(parser.parseOptionalComma())); 71301178654SLei Zhang 71401178654SLei Zhang if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { 71501178654SLei Zhang parser.emitError(parser.getNameLoc(), 71601178654SLei Zhang "offset specification must be given for all members"); 717b121c266SMarkus Böck return Type(); 71801178654SLei Zhang } 71901178654SLei Zhang 72001178654SLei Zhang if (failed(parser.parseRParen()) || failed(parser.parseGreater())) 721b121c266SMarkus Böck return Type(); 72201178654SLei Zhang 72301178654SLei Zhang if (!identifier.empty()) { 72401178654SLei Zhang if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, 72501178654SLei Zhang memberDecorationInfo))) 72601178654SLei Zhang return Type(); 72701178654SLei Zhang return idStructTy; 72801178654SLei Zhang } 72901178654SLei Zhang 73001178654SLei Zhang return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); 73101178654SLei Zhang } 73201178654SLei Zhang 73301178654SLei Zhang // spirv-type ::= array-type 73401178654SLei Zhang // | element-type 73501178654SLei Zhang // | image-type 73601178654SLei Zhang // | pointer-type 73701178654SLei Zhang // | runtime-array-type 7382ef24139SWeiwei Li // | sampled-image-type 73901178654SLei Zhang // | struct-type 74001178654SLei Zhang Type SPIRVDialect::parseType(DialectAsmParser &parser) const { 74101178654SLei Zhang StringRef keyword; 74201178654SLei Zhang if (parser.parseKeyword(&keyword)) 74301178654SLei Zhang return Type(); 74401178654SLei Zhang 74501178654SLei Zhang if (keyword == "array") 74601178654SLei Zhang return parseArrayType(*this, parser); 7474ba61f5aSJakub Kuderski if (keyword == "coopmatrix") 74801178654SLei Zhang return parseCooperativeMatrixType(*this, parser); 74901178654SLei Zhang if (keyword == "image") 75001178654SLei Zhang return parseImageType(*this, parser); 75101178654SLei Zhang if (keyword == "ptr") 75201178654SLei Zhang return parsePointerType(*this, parser); 75301178654SLei Zhang if (keyword == "rtarray") 75401178654SLei Zhang return parseRuntimeArrayType(*this, parser); 7552ef24139SWeiwei Li if (keyword == "sampled_image") 7562ef24139SWeiwei Li return parseSampledImageType(*this, parser); 75701178654SLei Zhang if (keyword == "struct") 75801178654SLei Zhang return parseStructType(*this, parser); 75901178654SLei Zhang if (keyword == "matrix") 76001178654SLei Zhang return parseMatrixType(*this, parser); 76101178654SLei Zhang parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; 76201178654SLei Zhang return Type(); 76301178654SLei Zhang } 76401178654SLei Zhang 76501178654SLei Zhang //===----------------------------------------------------------------------===// 76601178654SLei Zhang // Type Printing 76701178654SLei Zhang //===----------------------------------------------------------------------===// 76801178654SLei Zhang 76901178654SLei Zhang static void print(ArrayType type, DialectAsmPrinter &os) { 77001178654SLei Zhang os << "array<" << type.getNumElements() << " x " << type.getElementType(); 77101178654SLei Zhang if (unsigned stride = type.getArrayStride()) 77201178654SLei Zhang os << ", stride=" << stride; 77301178654SLei Zhang os << ">"; 77401178654SLei Zhang } 77501178654SLei Zhang 77601178654SLei Zhang static void print(RuntimeArrayType type, DialectAsmPrinter &os) { 77701178654SLei Zhang os << "rtarray<" << type.getElementType(); 77801178654SLei Zhang if (unsigned stride = type.getArrayStride()) 77901178654SLei Zhang os << ", stride=" << stride; 78001178654SLei Zhang os << ">"; 78101178654SLei Zhang } 78201178654SLei Zhang 78301178654SLei Zhang static void print(PointerType type, DialectAsmPrinter &os) { 78401178654SLei Zhang os << "ptr<" << type.getPointeeType() << ", " 78501178654SLei Zhang << stringifyStorageClass(type.getStorageClass()) << ">"; 78601178654SLei Zhang } 78701178654SLei Zhang 78801178654SLei Zhang static void print(ImageType type, DialectAsmPrinter &os) { 78901178654SLei Zhang os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) 79001178654SLei Zhang << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " 79101178654SLei Zhang << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " 79201178654SLei Zhang << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " 79301178654SLei Zhang << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " 79401178654SLei Zhang << stringifyImageFormat(type.getImageFormat()) << ">"; 79501178654SLei Zhang } 79601178654SLei Zhang 7972ef24139SWeiwei Li static void print(SampledImageType type, DialectAsmPrinter &os) { 7982ef24139SWeiwei Li os << "sampled_image<" << type.getImageType() << ">"; 7992ef24139SWeiwei Li } 8002ef24139SWeiwei Li 80101178654SLei Zhang static void print(StructType type, DialectAsmPrinter &os) { 802b121c266SMarkus Böck FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint; 80301178654SLei Zhang 80401178654SLei Zhang os << "struct<"; 80501178654SLei Zhang 80601178654SLei Zhang if (type.isIdentified()) { 80701178654SLei Zhang os << type.getIdentifier(); 80801178654SLei Zhang 809b121c266SMarkus Böck cyclicPrint = os.tryStartCyclicPrint(type); 810b121c266SMarkus Böck if (failed(cyclicPrint)) { 81101178654SLei Zhang os << ">"; 81201178654SLei Zhang return; 81301178654SLei Zhang } 81401178654SLei Zhang 81501178654SLei Zhang os << ", "; 81601178654SLei Zhang } 81701178654SLei Zhang 81801178654SLei Zhang os << "("; 81901178654SLei Zhang 82001178654SLei Zhang auto printMember = [&](unsigned i) { 82101178654SLei Zhang os << type.getElementType(i); 82201178654SLei Zhang SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; 82301178654SLei Zhang type.getMemberDecorations(i, decorations); 82401178654SLei Zhang if (type.hasOffset() || !decorations.empty()) { 82501178654SLei Zhang os << " ["; 82601178654SLei Zhang if (type.hasOffset()) { 82701178654SLei Zhang os << type.getMemberOffset(i); 82801178654SLei Zhang if (!decorations.empty()) 82901178654SLei Zhang os << ", "; 83001178654SLei Zhang } 83101178654SLei Zhang auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { 83201178654SLei Zhang os << stringifyDecoration(decoration.decoration); 83301178654SLei Zhang if (decoration.hasValue) { 83401178654SLei Zhang os << "=" << decoration.decorationValue; 83501178654SLei Zhang } 83601178654SLei Zhang }; 83701178654SLei Zhang llvm::interleaveComma(decorations, os, eachFn); 83801178654SLei Zhang os << "]"; 83901178654SLei Zhang } 84001178654SLei Zhang }; 84101178654SLei Zhang llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, 84201178654SLei Zhang printMember); 84301178654SLei Zhang os << ")>"; 84401178654SLei Zhang } 84501178654SLei Zhang 8464ba61f5aSJakub Kuderski static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { 8474ba61f5aSJakub Kuderski os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x" 8484ba61f5aSJakub Kuderski << type.getElementType() << ", " << type.getScope() << ", " 8494ba61f5aSJakub Kuderski << type.getUse() << ">"; 8504ba61f5aSJakub Kuderski } 8514ba61f5aSJakub Kuderski 85201178654SLei Zhang static void print(MatrixType type, DialectAsmPrinter &os) { 85301178654SLei Zhang os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); 85401178654SLei Zhang os << ">"; 85501178654SLei Zhang } 85601178654SLei Zhang 85701178654SLei Zhang void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { 85801178654SLei Zhang TypeSwitch<Type>(type) 859*b719ab4eSAndrea Faulds .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType, 860*b719ab4eSAndrea Faulds ImageType, SampledImageType, StructType, MatrixType>( 861*b719ab4eSAndrea Faulds [&](auto type) { print(type, os); }) 86201178654SLei Zhang .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); 86301178654SLei Zhang } 86401178654SLei Zhang 86501178654SLei Zhang //===----------------------------------------------------------------------===// 86601178654SLei Zhang // Constant 86701178654SLei Zhang //===----------------------------------------------------------------------===// 86801178654SLei Zhang 86901178654SLei Zhang Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, 87001178654SLei Zhang Attribute value, Type type, 87101178654SLei Zhang Location loc) { 8725dce7481SIvan Butygin if (auto poison = dyn_cast<ub::PoisonAttr>(value)) 8735dce7481SIvan Butygin return builder.create<ub::PoisonOp>(loc, type, poison); 8745dce7481SIvan Butygin 87501178654SLei Zhang if (!spirv::ConstantOp::isBuildableWith(type)) 87601178654SLei Zhang return nullptr; 87701178654SLei Zhang 87801178654SLei Zhang return builder.create<spirv::ConstantOp>(loc, type, value); 87901178654SLei Zhang } 88001178654SLei Zhang 88101178654SLei Zhang //===----------------------------------------------------------------------===// 88201178654SLei Zhang // Shader Interface ABI 88301178654SLei Zhang //===----------------------------------------------------------------------===// 88401178654SLei Zhang 88501178654SLei Zhang LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, 88601178654SLei Zhang NamedAttribute attribute) { 8870c7890c8SRiver Riddle StringRef symbol = attribute.getName().strref(); 8880c7890c8SRiver Riddle Attribute attr = attribute.getValue(); 88901178654SLei Zhang 89001178654SLei Zhang if (symbol == spirv::getEntryPointABIAttrName()) { 891c1fa60b4STres Popp if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) { 89201178654SLei Zhang return op->emitError("'") 893a31ff0afSMogball << symbol << "' attribute must be an entry point ABI attribute"; 894a31ff0afSMogball } 89501178654SLei Zhang } else if (symbol == spirv::getTargetEnvAttrName()) { 896c1fa60b4STres Popp if (!llvm::isa<spirv::TargetEnvAttr>(attr)) 89701178654SLei Zhang return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; 89801178654SLei Zhang } else { 89901178654SLei Zhang return op->emitError("found unsupported '") 90001178654SLei Zhang << symbol << "' attribute on operation"; 90101178654SLei Zhang } 90201178654SLei Zhang 90301178654SLei Zhang return success(); 90401178654SLei Zhang } 90501178654SLei Zhang 90601178654SLei Zhang /// Verifies the given SPIR-V `attribute` attached to a value of the given 90701178654SLei Zhang /// `valueType` is valid. 90801178654SLei Zhang static LogicalResult verifyRegionAttribute(Location loc, Type valueType, 90901178654SLei Zhang NamedAttribute attribute) { 9100c7890c8SRiver Riddle StringRef symbol = attribute.getName().strref(); 9110c7890c8SRiver Riddle Attribute attr = attribute.getValue(); 91201178654SLei Zhang 913747d8fb0SKohei Yamaguchi if (symbol == spirv::getInterfaceVarABIAttrName()) { 914c1fa60b4STres Popp auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr); 91501178654SLei Zhang if (!varABIAttr) 91601178654SLei Zhang return emitError(loc, "'") 91701178654SLei Zhang << symbol << "' must be a spirv::InterfaceVarABIAttr"; 91801178654SLei Zhang 91901178654SLei Zhang if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) 92001178654SLei Zhang return emitError(loc, "'") << symbol 92101178654SLei Zhang << "' attribute cannot specify storage class " 92201178654SLei Zhang "when attaching to a non-scalar value"; 92301178654SLei Zhang return success(); 92401178654SLei Zhang } 925747d8fb0SKohei Yamaguchi if (symbol == spirv::DecorationAttr::name) { 926747d8fb0SKohei Yamaguchi if (!isa<spirv::DecorationAttr>(attr)) 927747d8fb0SKohei Yamaguchi return emitError(loc, "'") 928747d8fb0SKohei Yamaguchi << symbol << "' must be a spirv::DecorationAttr"; 929747d8fb0SKohei Yamaguchi return success(); 930747d8fb0SKohei Yamaguchi } 931747d8fb0SKohei Yamaguchi 932747d8fb0SKohei Yamaguchi return emitError(loc, "found unsupported '") 933747d8fb0SKohei Yamaguchi << symbol << "' attribute on region argument"; 934747d8fb0SKohei Yamaguchi } 93501178654SLei Zhang 93601178654SLei Zhang LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, 93701178654SLei Zhang unsigned regionIndex, 93801178654SLei Zhang unsigned argIndex, 93901178654SLei Zhang NamedAttribute attribute) { 940747d8fb0SKohei Yamaguchi auto funcOp = dyn_cast<FunctionOpInterface>(op); 941747d8fb0SKohei Yamaguchi if (!funcOp) 942747d8fb0SKohei Yamaguchi return success(); 943747d8fb0SKohei Yamaguchi Type argType = funcOp.getArgumentTypes()[argIndex]; 944747d8fb0SKohei Yamaguchi 945747d8fb0SKohei Yamaguchi return verifyRegionAttribute(op->getLoc(), argType, attribute); 94601178654SLei Zhang } 94701178654SLei Zhang 94801178654SLei Zhang LogicalResult SPIRVDialect::verifyRegionResultAttribute( 94901178654SLei Zhang Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, 95001178654SLei Zhang NamedAttribute attribute) { 95101178654SLei Zhang return op->emitError("cannot attach SPIR-V attributes to region result"); 95201178654SLei Zhang } 953