xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp (revision 7359a6b7996f92e6659418d3d2e5b57c44d65e37)
1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::spirv;
18 
19 //===----------------------------------------------------------------------===//
20 // TableGen'erated attribute utility functions
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace spirv {
25 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
26 } // namespace spirv
27 
28 //===----------------------------------------------------------------------===//
29 // Attribute storage classes
30 //===----------------------------------------------------------------------===//
31 
32 namespace spirv {
33 namespace detail {
34 
35 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
36   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
37 
38   InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
39                                   Attribute storageClass)
40       : descriptorSet(descriptorSet), binding(binding),
41         storageClass(storageClass) {}
42 
43   bool operator==(const KeyTy &key) const {
44     return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
45            std::get<2>(key) == storageClass;
46   }
47 
48   static InterfaceVarABIAttributeStorage *
49   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
50     return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
51         InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
52                                         std::get<2>(key));
53   }
54 
55   Attribute descriptorSet;
56   Attribute binding;
57   Attribute storageClass;
58 };
59 
60 struct VerCapExtAttributeStorage : public AttributeStorage {
61   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
62 
63   VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
64                             Attribute extensions)
65       : version(version), capabilities(capabilities), extensions(extensions) {}
66 
67   bool operator==(const KeyTy &key) const {
68     return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
69            std::get<2>(key) == extensions;
70   }
71 
72   static VerCapExtAttributeStorage *
73   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
74     return new (allocator.allocate<VerCapExtAttributeStorage>())
75         VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
76                                   std::get<2>(key));
77   }
78 
79   Attribute version;
80   Attribute capabilities;
81   Attribute extensions;
82 };
83 
84 struct TargetEnvAttributeStorage : public AttributeStorage {
85   using KeyTy =
86       std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
87 
88   TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI,
89                             Vendor vendorID, DeviceType deviceType,
90                             uint32_t deviceID, Attribute limits)
91       : triple(triple), limits(limits), clientAPI(clientAPI),
92         vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}
93 
94   bool operator==(const KeyTy &key) const {
95     return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,
96                                   deviceID, limits);
97   }
98 
99   static TargetEnvAttributeStorage *
100   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
101     return new (allocator.allocate<TargetEnvAttributeStorage>())
102         TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
103                                   std::get<2>(key), std::get<3>(key),
104                                   std::get<4>(key), std::get<5>(key));
105   }
106 
107   Attribute triple;
108   Attribute limits;
109   ClientAPI clientAPI;
110   Vendor vendorID;
111   DeviceType deviceType;
112   uint32_t deviceID;
113 };
114 } // namespace detail
115 } // namespace spirv
116 } // namespace mlir
117 
118 //===----------------------------------------------------------------------===//
119 // InterfaceVarABIAttr
120 //===----------------------------------------------------------------------===//
121 
122 spirv::InterfaceVarABIAttr
123 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
124                                 std::optional<spirv::StorageClass> storageClass,
125                                 MLIRContext *context) {
126   Builder b(context);
127   auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
128   auto bindingAttr = b.getI32IntegerAttr(binding);
129   auto storageClassAttr =
130       storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
131                    : IntegerAttr();
132   return get(descriptorSetAttr, bindingAttr, storageClassAttr);
133 }
134 
135 spirv::InterfaceVarABIAttr
136 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
137                                 IntegerAttr storageClass) {
138   assert(descriptorSet && binding);
139   MLIRContext *context = descriptorSet.getContext();
140   return Base::get(context, descriptorSet, binding, storageClass);
141 }
142 
143 StringRef spirv::InterfaceVarABIAttr::getKindName() {
144   return "interface_var_abi";
145 }
146 
147 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
148   return llvm::cast<IntegerAttr>(getImpl()->binding).getInt();
149 }
150 
151 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
152   return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
153 }
154 
155 std::optional<spirv::StorageClass>
156 spirv::InterfaceVarABIAttr::getStorageClass() {
157   if (getImpl()->storageClass)
158     return static_cast<spirv::StorageClass>(
159         llvm::cast<IntegerAttr>(getImpl()->storageClass)
160             .getValue()
161             .getZExtValue());
162   return std::nullopt;
163 }
164 
165 LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
166     function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
167     IntegerAttr binding, IntegerAttr storageClass) {
168   if (!descriptorSet.getType().isSignlessInteger(32))
169     return emitError() << "expected 32-bit integer for descriptor set";
170 
171   if (!binding.getType().isSignlessInteger(32))
172     return emitError() << "expected 32-bit integer for binding";
173 
174   if (storageClass) {
175     if (auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) {
176       auto storageClassValue =
177           spirv::symbolizeStorageClass(storageClassAttr.getInt());
178       if (!storageClassValue)
179         return emitError() << "unknown storage class";
180     } else {
181       return emitError() << "expected valid storage class";
182     }
183   }
184 
185   return success();
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // VerCapExtAttr
190 //===----------------------------------------------------------------------===//
191 
192 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
193     spirv::Version version, ArrayRef<spirv::Capability> capabilities,
194     ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
195   Builder b(context);
196 
197   auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
198 
199   SmallVector<Attribute, 4> capAttrs;
200   capAttrs.reserve(capabilities.size());
201   for (spirv::Capability cap : capabilities)
202     capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
203 
204   SmallVector<Attribute, 4> extAttrs;
205   extAttrs.reserve(extensions.size());
206   for (spirv::Extension ext : extensions)
207     extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
208 
209   return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
210 }
211 
212 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
213                                                ArrayAttr capabilities,
214                                                ArrayAttr extensions) {
215   assert(version && capabilities && extensions);
216   MLIRContext *context = version.getContext();
217   return Base::get(context, version, capabilities, extensions);
218 }
219 
220 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
221 
222 spirv::Version spirv::VerCapExtAttr::getVersion() {
223   return static_cast<spirv::Version>(
224       llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
225 }
226 
227 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
228     : llvm::mapped_iterator<ArrayAttr::iterator,
229                             spirv::Extension (*)(Attribute)>(
230           it, [](Attribute attr) {
231             return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue());
232           }) {}
233 
234 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
235   auto range = getExtensionsAttr().getValue();
236   return {ext_iterator(range.begin()), ext_iterator(range.end())};
237 }
238 
239 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
240   return llvm::cast<ArrayAttr>(getImpl()->extensions);
241 }
242 
243 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
244     : llvm::mapped_iterator<ArrayAttr::iterator,
245                             spirv::Capability (*)(Attribute)>(
246           it, [](Attribute attr) {
247             return *symbolizeCapability(
248                 llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
249           }) {}
250 
251 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
252   auto range = getCapabilitiesAttr().getValue();
253   return {cap_iterator(range.begin()), cap_iterator(range.end())};
254 }
255 
256 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
257   return llvm::cast<ArrayAttr>(getImpl()->capabilities);
258 }
259 
260 LogicalResult spirv::VerCapExtAttr::verifyInvariants(
261     function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
262     ArrayAttr capabilities, ArrayAttr extensions) {
263   if (!version.getType().isSignlessInteger(32))
264     return emitError() << "expected 32-bit integer for version";
265 
266   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
267         if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
268           if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
269             return true;
270         return false;
271       }))
272     return emitError() << "unknown capability in capability list";
273 
274   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
275         if (auto strAttr = llvm::dyn_cast<StringAttr>(attr))
276           if (spirv::symbolizeExtension(strAttr.getValue()))
277             return true;
278         return false;
279       }))
280     return emitError() << "unknown extension in extension list";
281 
282   return success();
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // TargetEnvAttr
287 //===----------------------------------------------------------------------===//
288 
289 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
290     spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI,
291     Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
292   assert(triple && limits && "expected valid triple and limits");
293   MLIRContext *context = triple.getContext();
294   return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
295                    limits);
296 }
297 
298 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
299 
300 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
301   return llvm::cast<spirv::VerCapExtAttr>(getImpl()->triple);
302 }
303 
304 spirv::Version spirv::TargetEnvAttr::getVersion() const {
305   return getTripleAttr().getVersion();
306 }
307 
308 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
309   return getTripleAttr().getExtensions();
310 }
311 
312 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
313   return getTripleAttr().getExtensionsAttr();
314 }
315 
316 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
317   return getTripleAttr().getCapabilities();
318 }
319 
320 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
321   return getTripleAttr().getCapabilitiesAttr();
322 }
323 
324 spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const {
325   return getImpl()->clientAPI;
326 }
327 
328 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
329   return getImpl()->vendorID;
330 }
331 
332 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
333   return getImpl()->deviceType;
334 }
335 
336 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
337   return getImpl()->deviceID;
338 }
339 
340 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
341   return llvm::cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // ODS Generated Attributes
346 //===----------------------------------------------------------------------===//
347 
348 #define GET_ATTRDEF_CLASSES
349 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
350 
351 //===----------------------------------------------------------------------===//
352 // Attribute Parsing
353 //===----------------------------------------------------------------------===//
354 
355 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
356 /// of the parsed keyword, and returns failure if any error occurs.
357 static ParseResult
358 parseKeywordList(DialectAsmParser &parser,
359                  function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {
360   if (parser.parseLSquare())
361     return failure();
362 
363   // Special case for empty list.
364   if (succeeded(parser.parseOptionalRSquare()))
365     return success();
366 
367   // Keep parsing the keyword and an optional comma following it. If the comma
368   // is successfully parsed, then we have more keywords to parse.
369   if (failed(parser.parseCommaSeparatedList([&]() {
370         auto loc = parser.getCurrentLocation();
371         StringRef keyword;
372         if (parser.parseKeyword(&keyword) ||
373             failed(processKeyword(loc, keyword)))
374           return failure();
375         return success();
376       })))
377     return failure();
378   return parser.parseRSquare();
379 }
380 
381 /// Parses a spirv::InterfaceVarABIAttr.
382 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
383   if (parser.parseLess())
384     return {};
385 
386   Builder &builder = parser.getBuilder();
387 
388   if (parser.parseLParen())
389     return {};
390 
391   IntegerAttr descriptorSetAttr;
392   {
393     auto loc = parser.getCurrentLocation();
394     uint32_t descriptorSet = 0;
395     auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
396 
397     if (!descriptorSetParseResult.has_value() ||
398         failed(*descriptorSetParseResult)) {
399       parser.emitError(loc, "missing descriptor set");
400       return {};
401     }
402     descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
403   }
404 
405   if (parser.parseComma())
406     return {};
407 
408   IntegerAttr bindingAttr;
409   {
410     auto loc = parser.getCurrentLocation();
411     uint32_t binding = 0;
412     auto bindingParseResult = parser.parseOptionalInteger(binding);
413 
414     if (!bindingParseResult.has_value() || failed(*bindingParseResult)) {
415       parser.emitError(loc, "missing binding");
416       return {};
417     }
418     bindingAttr = builder.getI32IntegerAttr(binding);
419   }
420 
421   if (parser.parseRParen())
422     return {};
423 
424   IntegerAttr storageClassAttr;
425   {
426     if (succeeded(parser.parseOptionalComma())) {
427       auto loc = parser.getCurrentLocation();
428       StringRef storageClass;
429       if (parser.parseKeyword(&storageClass))
430         return {};
431 
432       if (auto storageClassSymbol =
433               spirv::symbolizeStorageClass(storageClass)) {
434         storageClassAttr = builder.getI32IntegerAttr(
435             static_cast<uint32_t>(*storageClassSymbol));
436       } else {
437         parser.emitError(loc, "unknown storage class: ") << storageClass;
438         return {};
439       }
440     }
441   }
442 
443   if (parser.parseGreater())
444     return {};
445 
446   return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
447                                          storageClassAttr);
448 }
449 
450 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
451   if (parser.parseLess())
452     return {};
453 
454   Builder &builder = parser.getBuilder();
455 
456   IntegerAttr versionAttr;
457   {
458     auto loc = parser.getCurrentLocation();
459     StringRef version;
460     if (parser.parseKeyword(&version) || parser.parseComma())
461       return {};
462 
463     if (auto versionSymbol = spirv::symbolizeVersion(version)) {
464       versionAttr =
465           builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
466     } else {
467       parser.emitError(loc, "unknown version: ") << version;
468       return {};
469     }
470   }
471 
472   ArrayAttr capabilitiesAttr;
473   {
474     SmallVector<Attribute, 4> capabilities;
475     SMLoc errorloc;
476     StringRef errorKeyword;
477 
478     auto processCapability = [&](SMLoc loc, StringRef capability) {
479       if (auto capSymbol = spirv::symbolizeCapability(capability)) {
480         capabilities.push_back(
481             builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
482         return success();
483       }
484       return errorloc = loc, errorKeyword = capability, failure();
485     };
486     if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
487       if (!errorKeyword.empty())
488         parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
489       return {};
490     }
491 
492     capabilitiesAttr = builder.getArrayAttr(capabilities);
493   }
494 
495   ArrayAttr extensionsAttr;
496   {
497     SmallVector<Attribute, 1> extensions;
498     SMLoc errorloc;
499     StringRef errorKeyword;
500 
501     auto processExtension = [&](SMLoc loc, StringRef extension) {
502       if (spirv::symbolizeExtension(extension)) {
503         extensions.push_back(builder.getStringAttr(extension));
504         return success();
505       }
506       return errorloc = loc, errorKeyword = extension, failure();
507     };
508     if (parseKeywordList(parser, processExtension)) {
509       if (!errorKeyword.empty())
510         parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
511       return {};
512     }
513 
514     extensionsAttr = builder.getArrayAttr(extensions);
515   }
516 
517   if (parser.parseGreater())
518     return {};
519 
520   return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
521                                    extensionsAttr);
522 }
523 
524 /// Parses a spirv::TargetEnvAttr.
525 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
526   if (parser.parseLess())
527     return {};
528 
529   spirv::VerCapExtAttr tripleAttr;
530   if (parser.parseAttribute(tripleAttr) || parser.parseComma())
531     return {};
532 
533   auto clientAPI = spirv::ClientAPI::Unknown;
534   if (succeeded(parser.parseOptionalKeyword("api"))) {
535     if (parser.parseEqual())
536       return {};
537     auto loc = parser.getCurrentLocation();
538     StringRef apiStr;
539     if (parser.parseKeyword(&apiStr))
540       return {};
541     if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
542       clientAPI = *apiSymbol;
543     else
544       parser.emitError(loc, "unknown client API: ") << apiStr;
545     if (parser.parseComma())
546       return {};
547   }
548 
549   // Parse [vendor[:device-type[:device-id]]]
550   Vendor vendorID = Vendor::Unknown;
551   DeviceType deviceType = DeviceType::Unknown;
552   uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
553   {
554     auto loc = parser.getCurrentLocation();
555     StringRef vendorStr;
556     if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
557       if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
558         vendorID = *vendorSymbol;
559       else
560         parser.emitError(loc, "unknown vendor: ") << vendorStr;
561 
562       if (succeeded(parser.parseOptionalColon())) {
563         loc = parser.getCurrentLocation();
564         StringRef deviceTypeStr;
565         if (parser.parseKeyword(&deviceTypeStr))
566           return {};
567         if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
568           deviceType = *deviceTypeSymbol;
569         else
570           parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
571 
572         if (succeeded(parser.parseOptionalColon())) {
573           loc = parser.getCurrentLocation();
574           if (parser.parseInteger(deviceID))
575             return {};
576         }
577       }
578       if (parser.parseComma())
579         return {};
580     }
581   }
582 
583   ResourceLimitsAttr limitsAttr;
584   if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
585     return {};
586 
587   return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID,
588                                    deviceType, deviceID);
589 }
590 
591 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
592                                        Type type) const {
593   // SPIR-V attributes are dictionaries so they do not have type.
594   if (type) {
595     parser.emitError(parser.getNameLoc(), "unexpected type");
596     return {};
597   }
598 
599   // Parse the kind keyword first.
600   StringRef attrKind;
601   Attribute attr;
602   OptionalParseResult result =
603       generatedAttributeParser(parser, &attrKind, type, attr);
604   if (result.has_value())
605     return attr;
606 
607   if (attrKind == spirv::TargetEnvAttr::getKindName())
608     return parseTargetEnvAttr(parser);
609   if (attrKind == spirv::VerCapExtAttr::getKindName())
610     return parseVerCapExtAttr(parser);
611   if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
612     return parseInterfaceVarABIAttr(parser);
613 
614   parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
615       << attrKind;
616   return {};
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // Attribute Printing
621 //===----------------------------------------------------------------------===//
622 
623 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
624   auto &os = printer.getStream();
625   printer << spirv::VerCapExtAttr::getKindName() << "<"
626           << spirv::stringifyVersion(triple.getVersion()) << ", [";
627   llvm::interleaveComma(
628       triple.getCapabilities(), os,
629       [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
630   printer << "], [";
631   llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
632     os << llvm::cast<StringAttr>(attr).getValue();
633   });
634   printer << "]>";
635 }
636 
637 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
638   printer << spirv::TargetEnvAttr::getKindName() << "<#spirv.";
639   print(targetEnv.getTripleAttr(), printer);
640   auto clientAPI = targetEnv.getClientAPI();
641   if (clientAPI != spirv::ClientAPI::Unknown)
642     printer << ", api=" << clientAPI;
643   spirv::Vendor vendorID = targetEnv.getVendorID();
644   spirv::DeviceType deviceType = targetEnv.getDeviceType();
645   uint32_t deviceID = targetEnv.getDeviceID();
646   if (vendorID != spirv::Vendor::Unknown) {
647     printer << ", " << spirv::stringifyVendor(vendorID);
648     if (deviceType != spirv::DeviceType::Unknown) {
649       printer << ":" << spirv::stringifyDeviceType(deviceType);
650       if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
651         printer << ":" << deviceID;
652     }
653   }
654   printer << ", " << targetEnv.getResourceLimits() << ">";
655 }
656 
657 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
658                   DialectAsmPrinter &printer) {
659   printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
660           << interfaceVarABIAttr.getDescriptorSet() << ", "
661           << interfaceVarABIAttr.getBinding() << ")";
662   auto storageClass = interfaceVarABIAttr.getStorageClass();
663   if (storageClass)
664     printer << ", " << spirv::stringifyStorageClass(*storageClass);
665   printer << ">";
666 }
667 
668 void SPIRVDialect::printAttribute(Attribute attr,
669                                   DialectAsmPrinter &printer) const {
670   if (succeeded(generatedAttributePrinter(attr, printer)))
671     return;
672 
673   if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
674     print(targetEnv, printer);
675   else if (auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr))
676     print(vceAttr, printer);
677   else if (auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr))
678     print(interfaceVarABIAttr, printer);
679   else
680     llvm_unreachable("unhandled SPIR-V attribute kind");
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // SPIR-V Dialect
685 //===----------------------------------------------------------------------===//
686 
687 void spirv::SPIRVDialect::registerAttributes() {
688   addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
689   addAttributes<
690 #define GET_ATTRDEF_LIST
691 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
692       >();
693 }
694