xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 
15 #include "SPIRVParsingUtils.h"
16 
17 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/Dialect/UB/IR/UBOps.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/DialectImplementation.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48   return llvm::any_of(region, [](Block &block) {
49     Operation *terminator = block.getTerminator();
50     return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51   });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
57   using DialectInlinerInterface::DialectInlinerInterface;
58 
59   /// All call operations within SPIRV can be inlined.
60   bool isLegalToInline(Operation *call, Operation *callable,
61                        bool wouldBeCloned) const final {
62     return true;
63   }
64 
65   /// Returns true if the given region 'src' can be inlined into the region
66   /// 'dest' that is attached to an operation registered to the current dialect.
67   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68                        IRMapping &) const final {
69     // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70     // spirv.mlir.loop operations.
71     auto *op = dest->getParentOp();
72     return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73   }
74 
75   /// Returns true if the given operation 'op', that is registered to this
76   /// dialect, can be inlined into the region 'dest' that is attached to an
77   /// operation registered to the current dialect.
78   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79                        IRMapping &) const final {
80     // TODO: Enable inlining structured control flows with return.
81     if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82         containsReturn(op->getRegion(0)))
83       return false;
84     // TODO: we need to filter OpKill here to avoid inlining it to
85     // a loop continue construct:
86     // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87     // However OpKill is fragment shader specific and we don't support it yet.
88     return true;
89   }
90 
91   /// Handle the given inlined terminator by replacing it with a new operation
92   /// as necessary.
93   void handleTerminator(Operation *op, Block *newDest) const final {
94     if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95       OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96       op->erase();
97     } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98       OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99                                             retValOp->getOperands());
100       op->erase();
101     }
102   }
103 
104   /// Handle the given inlined terminator by replacing it with a new operation
105   /// as necessary.
106   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107     // Only spirv.ReturnValue needs to be handled here.
108     auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109     if (!retValOp)
110       return;
111 
112     // Replace the values directly with the return operands.
113     assert(valuesToRepl.size() == 1 &&
114            "spirv.ReturnValue expected to only handle one result");
115     valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
116   }
117 };
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // SPIR-V Dialect
122 //===----------------------------------------------------------------------===//
123 
/// Dialect initialization hook: registers the SPIR-V attributes, types and
/// operations, the inliner interface defined in this file, and the promised
/// GPU target interface.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  // The op list is pulled in from SPIRVOps.cpp.inc, selected by GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Enables inlining into spirv.func / spirv.mlir.selection / spirv.mlir.loop
  // (see SPIRVInlinerInterface above).
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // Promises a gpu::TargetAttrInterface implementation for TargetEnvAttr. The
  // implementation is not in this file (presumably registered as an external
  // model elsewhere; confirm).
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140 
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // Type Parsing
147 //===----------------------------------------------------------------------===//
148 
// Forward declarations.
// parseAndVerify<ValTy> parses and verifies a single value of type ValTy. The
// primary template and its Type/unsigned specializations are defined further
// below, but are declared here so that helpers such as
// parseOptionalArrayStride can use them first.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160 
161 static Type parseAndVerifyType(SPIRVDialect const &dialect,
162                                DialectAsmParser &parser) {
163   Type type;
164   SMLoc typeLoc = parser.getCurrentLocation();
165   if (parser.parseType(type))
166     return Type();
167 
168   // Allow SPIR-V dialect types
169   if (&type.getDialect() == &dialect)
170     return type;
171 
172   // Check other allowed types
173   if (auto t = llvm::dyn_cast<FloatType>(type)) {
174     if (type.isBF16()) {
175       parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
176       return Type();
177     }
178   } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179     if (!ScalarType::isValid(t)) {
180       parser.emitError(typeLoc,
181                        "only 1/8/16/32/64-bit integer type allowed but found ")
182           << type;
183       return Type();
184     }
185   } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186     if (t.getRank() != 1) {
187       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
188       return Type();
189     }
190     if (t.getNumElements() > 4) {
191       parser.emitError(
192           typeLoc, "vector length has to be less than or equal to 4 but found ")
193           << t.getNumElements();
194       return Type();
195     }
196   } else {
197     parser.emitError(typeLoc, "cannot use ")
198         << type << " to compose SPIR-V types";
199     return Type();
200   }
201 
202   return type;
203 }
204 
205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206                                      DialectAsmParser &parser) {
207   Type type;
208   SMLoc typeLoc = parser.getCurrentLocation();
209   if (parser.parseType(type))
210     return Type();
211 
212   if (auto t = llvm::dyn_cast<VectorType>(type)) {
213     if (t.getRank() != 1) {
214       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215       return Type();
216     }
217     if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218       parser.emitError(typeLoc,
219                        "matrix columns size has to be less than or equal "
220                        "to 4 and greater than or equal 2, but found ")
221           << t.getNumElements();
222       return Type();
223     }
224 
225     if (!llvm::isa<FloatType>(t.getElementType())) {
226       parser.emitError(typeLoc, "matrix columns' elements must be of "
227                                 "Float type, got ")
228           << t.getElementType();
229       return Type();
230     }
231   } else {
232     parser.emitError(typeLoc, "matrix must be composed using vector "
233                               "type, got ")
234         << type;
235     return Type();
236   }
237 
238   return type;
239 }
240 
241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242                                            DialectAsmParser &parser) {
243   Type type;
244   SMLoc typeLoc = parser.getCurrentLocation();
245   if (parser.parseType(type))
246     return Type();
247 
248   if (!llvm::isa<ImageType>(type)) {
249     parser.emitError(typeLoc,
250                      "sampled image must be composed using image type, got ")
251         << type;
252     return Type();
253   }
254 
255   return type;
256 }
257 
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260 /// missing.
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262                                               DialectAsmParser &parser,
263                                               unsigned &stride) {
264   if (failed(parser.parseOptionalComma())) {
265     stride = 0;
266     return success();
267   }
268 
269   if (parser.parseKeyword("stride") || parser.parseEqual())
270     return failure();
271 
272   SMLoc strideLoc = parser.getCurrentLocation();
273   std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274   if (!optStride)
275     return failure();
276 
277   if (!(stride = *optStride)) {
278     parser.emitError(strideLoc, "ArrayStride must be greater than zero");
279     return failure();
280   }
281   return success();
282 }
283 
284 // element-type ::= integer-type
285 //                | floating-point-type
286 //                | vector-type
287 //                | spirv-type
288 //
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 //                (`,` `stride` `=` integer-literal)? `>`
291 static Type parseArrayType(SPIRVDialect const &dialect,
292                            DialectAsmParser &parser) {
293   if (parser.parseLess())
294     return Type();
295 
296   SmallVector<int64_t, 1> countDims;
297   SMLoc countLoc = parser.getCurrentLocation();
298   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
299     return Type();
300   if (countDims.size() != 1) {
301     parser.emitError(countLoc,
302                      "expected single integer for array element count");
303     return Type();
304   }
305 
306   // According to the SPIR-V spec:
307   // "Length is the number of elements in the array. It must be at least 1."
308   int64_t count = countDims[0];
309   if (count == 0) {
310     parser.emitError(countLoc, "expected array length greater than 0");
311     return Type();
312   }
313 
314   Type elementType = parseAndVerifyType(dialect, parser);
315   if (!elementType)
316     return Type();
317 
318   unsigned stride = 0;
319   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320     return Type();
321 
322   if (parser.parseGreater())
323     return Type();
324   return ArrayType::get(elementType, count, stride);
325 }
326 
327 // cooperative-matrix-type ::=
328 //   `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329 //                           scope `,` use `>`
330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331                                        DialectAsmParser &parser) {
332   if (parser.parseLess())
333     return {};
334 
335   SmallVector<int64_t, 2> dims;
336   SMLoc countLoc = parser.getCurrentLocation();
337   if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
338     return {};
339 
340   if (dims.size() != 2) {
341     parser.emitError(countLoc, "expected row and column count");
342     return {};
343   }
344 
345   auto elementTy = parseAndVerifyType(dialect, parser);
346   if (!elementTy)
347     return {};
348 
349   Scope scope;
350   if (parser.parseComma() ||
351       spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352     return {};
353 
354   CooperativeMatrixUseKHR use;
355   if (parser.parseComma() ||
356       spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357     return {};
358 
359   if (parser.parseGreater())
360     return {};
361 
362   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
363 }
364 
365 // TODO: Reorder methods to be utilities first and parse*Type
366 // methods in alphabetical order
367 //
368 // storage-class ::= `UniformConstant`
369 //                 | `Uniform`
370 //                 | `Workgroup`
371 //                 | <and other storage classes...>
372 //
373 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
374 static Type parsePointerType(SPIRVDialect const &dialect,
375                              DialectAsmParser &parser) {
376   if (parser.parseLess())
377     return Type();
378 
379   auto pointeeType = parseAndVerifyType(dialect, parser);
380   if (!pointeeType)
381     return Type();
382 
383   StringRef storageClassSpec;
384   SMLoc storageClassLoc = parser.getCurrentLocation();
385   if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
386     return Type();
387 
388   auto storageClass = symbolizeStorageClass(storageClassSpec);
389   if (!storageClass) {
390     parser.emitError(storageClassLoc, "unknown storage class: ")
391         << storageClassSpec;
392     return Type();
393   }
394   if (parser.parseGreater())
395     return Type();
396   return PointerType::get(pointeeType, *storageClass);
397 }
398 
399 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
400 //                        (`,` `stride` `=` integer-literal)? `>`
401 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
402                                   DialectAsmParser &parser) {
403   if (parser.parseLess())
404     return Type();
405 
406   Type elementType = parseAndVerifyType(dialect, parser);
407   if (!elementType)
408     return Type();
409 
410   unsigned stride = 0;
411   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
412     return Type();
413 
414   if (parser.parseGreater())
415     return Type();
416   return RuntimeArrayType::get(elementType, stride);
417 }
418 
419 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
420 static Type parseMatrixType(SPIRVDialect const &dialect,
421                             DialectAsmParser &parser) {
422   if (parser.parseLess())
423     return Type();
424 
425   SmallVector<int64_t, 1> countDims;
426   SMLoc countLoc = parser.getCurrentLocation();
427   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
428     return Type();
429   if (countDims.size() != 1) {
430     parser.emitError(countLoc, "expected single unsigned "
431                                "integer for number of columns");
432     return Type();
433   }
434 
435   int64_t columnCount = countDims[0];
436   // According to the specification, Matrices can have 2, 3, or 4 columns
437   if (columnCount < 2 || columnCount > 4) {
438     parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
439                                "columns");
440     return Type();
441   }
442 
443   Type columnType = parseAndVerifyMatrixType(dialect, parser);
444   if (!columnType)
445     return Type();
446 
447   if (parser.parseGreater())
448     return Type();
449 
450   return MatrixType::get(columnType, columnCount);
451 }
452 
453 // Specialize this function to parse each of the parameters that define an
454 // ImageType. By default it assumes this is an enum type.
455 template <typename ValTy>
456 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
457                                            DialectAsmParser &parser) {
458   StringRef enumSpec;
459   SMLoc enumLoc = parser.getCurrentLocation();
460   if (parser.parseKeyword(&enumSpec)) {
461     return std::nullopt;
462   }
463 
464   auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
465   if (!val)
466     parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
467   return val;
468 }
469 
470 template <>
471 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
472                                          DialectAsmParser &parser) {
473   // TODO: Further verify that the element type can be sampled
474   auto ty = parseAndVerifyType(dialect, parser);
475   if (!ty)
476     return std::nullopt;
477   return ty;
478 }
479 
480 template <typename IntTy>
481 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
482                                                   DialectAsmParser &parser) {
483   IntTy offsetVal = std::numeric_limits<IntTy>::max();
484   if (parser.parseInteger(offsetVal))
485     return std::nullopt;
486   return offsetVal;
487 }
488 
489 template <>
490 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
491                                                  DialectAsmParser &parser) {
492   return parseAndVerifyInteger<unsigned>(dialect, parser);
493 }
494 
495 namespace {
496 // Functor object to parse a comma separated list of specs. The function
497 // parseAndVerify does the actual parsing and verification of individual
498 // elements. This is a functor since parsing the last element of the list
499 // (termination condition) needs partial specialization.
500 template <typename ParseType, typename... Args>
501 struct ParseCommaSeparatedList {
502   std::optional<std::tuple<ParseType, Args...>>
503   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
504     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
505     if (!parseVal)
506       return std::nullopt;
507 
508     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
509     if (numArgs != 0 && failed(parser.parseComma()))
510       return std::nullopt;
511     auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
512     if (!remainingValues)
513       return std::nullopt;
514     return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
515                           remainingValues.value());
516   }
517 };
518 
519 // Partial specialization of the function to parse a comma separated list of
520 // specs to parse the last element of the list.
521 template <typename ParseType>
522 struct ParseCommaSeparatedList<ParseType> {
523   std::optional<std::tuple<ParseType>>
524   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
525     if (auto value = parseAndVerify<ParseType>(dialect, parser))
526       return std::tuple<ParseType>(*value);
527     return std::nullopt;
528   }
529 };
530 } // namespace
531 
532 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
533 //
534 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
535 //
536 // arrayed-info ::= `NonArrayed` | `Arrayed`
537 //
538 // sampling-info ::= `SingleSampled` | `MultiSampled`
539 //
540 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
541 //
542 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
543 //
544 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
545 //                              arrayed-info `,` sampling-info `,`
546 //                              sampler-use-info `,` format `>`
547 static Type parseImageType(SPIRVDialect const &dialect,
548                            DialectAsmParser &parser) {
549   if (parser.parseLess())
550     return Type();
551 
552   auto value =
553       ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
554                               ImageSamplingInfo, ImageSamplerUseInfo,
555                               ImageFormat>{}(dialect, parser);
556   if (!value)
557     return Type();
558 
559   if (parser.parseGreater())
560     return Type();
561   return ImageType::get(*value);
562 }
563 
564 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
565 static Type parseSampledImageType(SPIRVDialect const &dialect,
566                                   DialectAsmParser &parser) {
567   if (parser.parseLess())
568     return Type();
569 
570   Type parsedType = parseAndVerifySampledImageType(dialect, parser);
571   if (!parsedType)
572     return Type();
573 
574   if (parser.parseGreater())
575     return Type();
576   return SampledImageType::get(parsedType);
577 }
578 
579 // Parse decorations associated with a member.
580 static ParseResult parseStructMemberDecorations(
581     SPIRVDialect const &dialect, DialectAsmParser &parser,
582     ArrayRef<Type> memberTypes,
583     SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
584     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
585 
586   // Check if the first element is offset.
587   SMLoc offsetLoc = parser.getCurrentLocation();
588   StructType::OffsetInfo offset = 0;
589   OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
590   if (offsetParseResult.has_value()) {
591     if (failed(*offsetParseResult))
592       return failure();
593 
594     if (offsetInfo.size() != memberTypes.size() - 1) {
595       return parser.emitError(offsetLoc,
596                               "offset specification must be given for "
597                               "all members");
598     }
599     offsetInfo.push_back(offset);
600   }
601 
602   // Check for no spirv::Decorations.
603   if (succeeded(parser.parseOptionalRSquare()))
604     return success();
605 
606   // If there was an offset, make sure to parse the comma.
607   if (offsetParseResult.has_value() && parser.parseComma())
608     return failure();
609 
610   // Check for spirv::Decorations.
611   auto parseDecorations = [&]() {
612     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
613     if (!memberDecoration)
614       return failure();
615 
616     // Parse member decoration value if it exists.
617     if (succeeded(parser.parseOptionalEqual())) {
618       auto memberDecorationValue =
619           parseAndVerifyInteger<uint32_t>(dialect, parser);
620 
621       if (!memberDecorationValue)
622         return failure();
623 
624       memberDecorationInfo.emplace_back(
625           static_cast<uint32_t>(memberTypes.size() - 1), 1,
626           memberDecoration.value(), memberDecorationValue.value());
627     } else {
628       memberDecorationInfo.emplace_back(
629           static_cast<uint32_t>(memberTypes.size() - 1), 0,
630           memberDecoration.value(), 0);
631     }
632     return success();
633   };
634   if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
635       failed(parser.parseRSquare()))
636     return failure();
637 
638   return success();
639 }
640 
641 // struct-member-decoration ::= integer-literal? spirv-decoration*
642 // struct-type ::=
643 //             `!spirv.struct<` (id `,`)?
644 //                          `(`
645 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
646 //                          `)>`
647 static Type parseStructType(SPIRVDialect const &dialect,
648                             DialectAsmParser &parser) {
649   // TODO: This function is quite lengthy. Break it down into smaller chunks.
650 
651   if (parser.parseLess())
652     return Type();
653 
654   StringRef identifier;
655   FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
656 
657   // Check if this is an identified struct type.
658   if (succeeded(parser.parseOptionalKeyword(&identifier))) {
659     // Check if this is a possible recursive reference.
660     auto structType =
661         StructType::getIdentified(dialect.getContext(), identifier);
662     cyclicParse = parser.tryStartCyclicParse(structType);
663     if (succeeded(parser.parseOptionalGreater())) {
664       if (succeeded(cyclicParse)) {
665         parser.emitError(
666             parser.getNameLoc(),
667             "recursive struct reference not nested in struct definition");
668 
669         return Type();
670       }
671 
672       return structType;
673     }
674 
675     if (failed(parser.parseComma()))
676       return Type();
677 
678     if (failed(cyclicParse)) {
679       parser.emitError(parser.getNameLoc(),
680                        "identifier already used for an enclosing struct");
681       return Type();
682     }
683   }
684 
685   if (failed(parser.parseLParen()))
686     return Type();
687 
688   if (succeeded(parser.parseOptionalRParen()) &&
689       succeeded(parser.parseOptionalGreater())) {
690     return StructType::getEmpty(dialect.getContext(), identifier);
691   }
692 
693   StructType idStructTy;
694 
695   if (!identifier.empty())
696     idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
697 
698   SmallVector<Type, 4> memberTypes;
699   SmallVector<StructType::OffsetInfo, 4> offsetInfo;
700   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
701 
702   do {
703     Type memberType;
704     if (parser.parseType(memberType))
705       return Type();
706     memberTypes.push_back(memberType);
707 
708     if (succeeded(parser.parseOptionalLSquare()))
709       if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
710                                        memberDecorationInfo))
711         return Type();
712   } while (succeeded(parser.parseOptionalComma()));
713 
714   if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
715     parser.emitError(parser.getNameLoc(),
716                      "offset specification must be given for all members");
717     return Type();
718   }
719 
720   if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721     return Type();
722 
723   if (!identifier.empty()) {
724     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
725                                      memberDecorationInfo)))
726       return Type();
727     return idStructTy;
728   }
729 
730   return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
731 }
732 
733 // spirv-type ::= array-type
734 //              | element-type
735 //              | image-type
736 //              | pointer-type
737 //              | runtime-array-type
738 //              | sampled-image-type
739 //              | struct-type
740 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
741   StringRef keyword;
742   if (parser.parseKeyword(&keyword))
743     return Type();
744 
745   if (keyword == "array")
746     return parseArrayType(*this, parser);
747   if (keyword == "coopmatrix")
748     return parseCooperativeMatrixType(*this, parser);
749   if (keyword == "image")
750     return parseImageType(*this, parser);
751   if (keyword == "ptr")
752     return parsePointerType(*this, parser);
753   if (keyword == "rtarray")
754     return parseRuntimeArrayType(*this, parser);
755   if (keyword == "sampled_image")
756     return parseSampledImageType(*this, parser);
757   if (keyword == "struct")
758     return parseStructType(*this, parser);
759   if (keyword == "matrix")
760     return parseMatrixType(*this, parser);
761   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
762   return Type();
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // Type Printing
767 //===----------------------------------------------------------------------===//
768 
769 static void print(ArrayType type, DialectAsmPrinter &os) {
770   os << "array<" << type.getNumElements() << " x " << type.getElementType();
771   if (unsigned stride = type.getArrayStride())
772     os << ", stride=" << stride;
773   os << ">";
774 }
775 
776 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
777   os << "rtarray<" << type.getElementType();
778   if (unsigned stride = type.getArrayStride())
779     os << ", stride=" << stride;
780   os << ">";
781 }
782 
783 static void print(PointerType type, DialectAsmPrinter &os) {
784   os << "ptr<" << type.getPointeeType() << ", "
785      << stringifyStorageClass(type.getStorageClass()) << ">";
786 }
787 
788 static void print(ImageType type, DialectAsmPrinter &os) {
789   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
790      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
791      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
792      << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
793      << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
794      << stringifyImageFormat(type.getImageFormat()) << ">";
795 }
796 
797 static void print(SampledImageType type, DialectAsmPrinter &os) {
798   os << "sampled_image<" << type.getImageType() << ">";
799 }
800 
801 static void print(StructType type, DialectAsmPrinter &os) {
802   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
803 
804   os << "struct<";
805 
806   if (type.isIdentified()) {
807     os << type.getIdentifier();
808 
809     cyclicPrint = os.tryStartCyclicPrint(type);
810     if (failed(cyclicPrint)) {
811       os << ">";
812       return;
813     }
814 
815     os << ", ";
816   }
817 
818   os << "(";
819 
820   auto printMember = [&](unsigned i) {
821     os << type.getElementType(i);
822     SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
823     type.getMemberDecorations(i, decorations);
824     if (type.hasOffset() || !decorations.empty()) {
825       os << " [";
826       if (type.hasOffset()) {
827         os << type.getMemberOffset(i);
828         if (!decorations.empty())
829           os << ", ";
830       }
831       auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
832         os << stringifyDecoration(decoration.decoration);
833         if (decoration.hasValue) {
834           os << "=" << decoration.decorationValue;
835         }
836       };
837       llvm::interleaveComma(decorations, os, eachFn);
838       os << "]";
839     }
840   };
841   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
842                         printMember);
843   os << ")>";
844 }
845 
846 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
847   os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
848      << type.getElementType() << ", " << type.getScope() << ", "
849      << type.getUse() << ">";
850 }
851 
852 static void print(MatrixType type, DialectAsmPrinter &os) {
853   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
854   os << ">";
855 }
856 
857 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
858   TypeSwitch<Type>(type)
859       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
860             ImageType, SampledImageType, StructType, MatrixType>(
861           [&](auto type) { print(type, os); })
862       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
863 }
864 
865 //===----------------------------------------------------------------------===//
866 // Constant
867 //===----------------------------------------------------------------------===//
868 
869 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
870                                              Attribute value, Type type,
871                                              Location loc) {
872   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
873     return builder.create<ub::PoisonOp>(loc, type, poison);
874 
875   if (!spirv::ConstantOp::isBuildableWith(type))
876     return nullptr;
877 
878   return builder.create<spirv::ConstantOp>(loc, type, value);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // Shader Interface ABI
883 //===----------------------------------------------------------------------===//
884 
885 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
886                                                      NamedAttribute attribute) {
887   StringRef symbol = attribute.getName().strref();
888   Attribute attr = attribute.getValue();
889 
890   if (symbol == spirv::getEntryPointABIAttrName()) {
891     if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
892       return op->emitError("'")
893              << symbol << "' attribute must be an entry point ABI attribute";
894     }
895   } else if (symbol == spirv::getTargetEnvAttrName()) {
896     if (!llvm::isa<spirv::TargetEnvAttr>(attr))
897       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
898   } else {
899     return op->emitError("found unsupported '")
900            << symbol << "' attribute on operation";
901   }
902 
903   return success();
904 }
905 
906 /// Verifies the given SPIR-V `attribute` attached to a value of the given
907 /// `valueType` is valid.
908 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
909                                            NamedAttribute attribute) {
910   StringRef symbol = attribute.getName().strref();
911   Attribute attr = attribute.getValue();
912 
913   if (symbol == spirv::getInterfaceVarABIAttrName()) {
914     auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
915     if (!varABIAttr)
916       return emitError(loc, "'")
917              << symbol << "' must be a spirv::InterfaceVarABIAttr";
918 
919     if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
920       return emitError(loc, "'") << symbol
921                                  << "' attribute cannot specify storage class "
922                                     "when attaching to a non-scalar value";
923     return success();
924   }
925   if (symbol == spirv::DecorationAttr::name) {
926     if (!isa<spirv::DecorationAttr>(attr))
927       return emitError(loc, "'")
928              << symbol << "' must be a spirv::DecorationAttr";
929     return success();
930   }
931 
932   return emitError(loc, "found unsupported '")
933          << symbol << "' attribute on region argument";
934 }
935 
936 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
937                                                      unsigned regionIndex,
938                                                      unsigned argIndex,
939                                                      NamedAttribute attribute) {
940   auto funcOp = dyn_cast<FunctionOpInterface>(op);
941   if (!funcOp)
942     return success();
943   Type argType = funcOp.getArgumentTypes()[argIndex];
944 
945   return verifyRegionAttribute(op->getLoc(), argType, attribute);
946 }
947 
948 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
949     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
950     NamedAttribute attribute) {
951   return op->emitError("cannot attach SPIR-V attributes to region result");
952 }
953