xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (revision a01097faca35a9a8927c8b0c514bc35dcebec00f)
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
22 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
23 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/OpImplementation.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
31 #include "mlir/Interfaces/FunctionImplementation.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 #include <type_traits>
42 
43 using namespace mlir;
44 using namespace mlir::spirv::AttrNames;
45 
46 //===----------------------------------------------------------------------===//
47 // Common utility functions
48 //===----------------------------------------------------------------------===//
49 
50 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
51   auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
52   if (!constOp) {
53     return failure();
54   }
55   auto valueAttr = constOp.getValue();
56   auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57   if (!integerValueAttr) {
58     return failure();
59   }
60 
61   if (integerValueAttr.getType().isSignlessInteger())
62     value = integerValueAttr.getInt();
63   else
64     value = integerValueAttr.getSInt();
65 
66   return success();
67 }
68 
69 LogicalResult
70 spirv::verifyMemorySemantics(Operation *op,
71                              spirv::MemorySemantics memorySemantics) {
72   // According to the SPIR-V specification:
73   // "Despite being a mask and allowing multiple bits to be combined, it is
74   // invalid for more than one of these four bits to be set: Acquire, Release,
75   // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
76   // Release semantics is done by setting the AcquireRelease bit, not by setting
77   // two bits."
78   auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79                         spirv::MemorySemantics::Release |
80                         spirv::MemorySemantics::AcquireRelease |
81                         spirv::MemorySemantics::SequentiallyConsistent;
82 
83   auto bitCount =
84       llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
85   if (bitCount > 1) {
86     return op->emitError(
87         "expected at most one of these four memory constraints "
88         "to be set: `Acquire`, `Release`,"
89         "`AcquireRelease` or `SequentiallyConsistent`");
90   }
91   return success();
92 }
93 
94 void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
95                                      SmallVectorImpl<StringRef> &elidedAttrs) {
96   // Print optional descriptor binding
97   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
98       stringifyDecoration(spirv::Decoration::DescriptorSet));
99   auto bindingName = llvm::convertToSnakeFromCamelCase(
100       stringifyDecoration(spirv::Decoration::Binding));
101   auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
102   auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
103   if (descriptorSet && binding) {
104     elidedAttrs.push_back(descriptorSetName);
105     elidedAttrs.push_back(bindingName);
106     printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
107             << ")";
108   }
109 
110   // Print BuiltIn attribute if present
111   auto builtInName = llvm::convertToSnakeFromCamelCase(
112       stringifyDecoration(spirv::Decoration::BuiltIn));
113   if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
114     printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
115     elidedAttrs.push_back(builtInName);
116   }
117 
118   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
119 }
120 
121 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
122                                                    OperationState &result) {
123   SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
124   Type type;
125   // If the operand list is in-between parentheses, then we have a generic form.
126   // (see the fallback in `printOneResultOp`).
127   SMLoc loc = parser.getCurrentLocation();
128   if (!parser.parseOptionalLParen()) {
129     if (parser.parseOperandList(ops) || parser.parseRParen() ||
130         parser.parseOptionalAttrDict(result.attributes) ||
131         parser.parseColon() || parser.parseType(type))
132       return failure();
133     auto fnType = llvm::dyn_cast<FunctionType>(type);
134     if (!fnType) {
135       parser.emitError(loc, "expected function type");
136       return failure();
137     }
138     if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
139       return failure();
140     result.addTypes(fnType.getResults());
141     return success();
142   }
143   return failure(parser.parseOperandList(ops) ||
144                  parser.parseOptionalAttrDict(result.attributes) ||
145                  parser.parseColonType(type) ||
146                  parser.resolveOperands(ops, type, result.operands) ||
147                  parser.addTypeToList(type, result.types));
148 }
149 
150 static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
151   assert(op->getNumResults() == 1 && "op should have one result");
152 
153   // If not all the operand and result types are the same, just use the
154   // generic assembly form to avoid omitting information in printing.
155   auto resultType = op->getResult(0).getType();
156   if (llvm::any_of(op->getOperandTypes(),
157                    [&](Type type) { return type != resultType; })) {
158     p.printGenericOp(op, /*printOpName=*/false);
159     return;
160   }
161 
162   p << ' ';
163   p.printOperands(op->getOperands());
164   p.printOptionalAttrDict(op->getAttrs());
165   // Now we can output only one type for all operands and the result.
166   p << " : " << resultType;
167 }
168 
169 template <typename Op>
170 static LogicalResult verifyImageOperands(Op imageOp,
171                                          spirv::ImageOperandsAttr attr,
172                                          Operation::operand_range operands) {
173   if (!attr) {
174     if (operands.empty())
175       return success();
176 
177     return imageOp.emitError("the Image Operands should encode what operands "
178                              "follow, as per Image Operands");
179   }
180 
181   // TODO: Add the validation rules for the following Image Operands.
182   spirv::ImageOperands noSupportOperands =
183       spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
184       spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
185       spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
186       spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
187       spirv::ImageOperands::MakeTexelAvailable |
188       spirv::ImageOperands::MakeTexelVisible |
189       spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
190 
191   if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
192     llvm_unreachable("unimplemented operands of Image Operands");
193 
194   return success();
195 }
196 
197 template <typename BlockReadWriteOpTy>
198 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
199                                                         Value ptr, Value val) {
200   auto valType = val.getType();
201   if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
202     valType = valVecTy.getElementType();
203 
204   if (valType !=
205       llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
206     return op.emitOpError("mismatch in result type and pointer type");
207   }
208   return success();
209 }
210 
211 /// Walks the given type hierarchy with the given indices, potentially down
212 /// to component granularity, to select an element type. Returns null type and
213 /// emits errors with the given loc on failure.
214 static Type
215 getElementType(Type type, ArrayRef<int32_t> indices,
216                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
217   if (indices.empty()) {
218     emitErrorFn("expected at least one index for spirv.CompositeExtract");
219     return nullptr;
220   }
221 
222   for (auto index : indices) {
223     if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
224       if (cType.hasCompileTimeKnownNumElements() &&
225           (index < 0 ||
226            static_cast<uint64_t>(index) >= cType.getNumElements())) {
227         emitErrorFn("index ") << index << " out of bounds for " << type;
228         return nullptr;
229       }
230       type = cType.getElementType(index);
231     } else {
232       emitErrorFn("cannot extract from non-composite type ")
233           << type << " with index " << index;
234       return nullptr;
235     }
236   }
237   return type;
238 }
239 
240 static Type
241 getElementType(Type type, Attribute indices,
242                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
243   auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
244   if (!indicesArrayAttr) {
245     emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
246     return nullptr;
247   }
248   if (indicesArrayAttr.empty()) {
249     emitErrorFn("expected at least one index for spirv.CompositeExtract");
250     return nullptr;
251   }
252 
253   SmallVector<int32_t, 2> indexVals;
254   for (auto indexAttr : indicesArrayAttr) {
255     auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
256     if (!indexIntAttr) {
257       emitErrorFn("expected an 32-bit integer for index, but found '")
258           << indexAttr << "'";
259       return nullptr;
260     }
261     indexVals.push_back(indexIntAttr.getInt());
262   }
263   return getElementType(type, indexVals, emitErrorFn);
264 }
265 
266 static Type getElementType(Type type, Attribute indices, Location loc) {
267   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
268     return ::mlir::emitError(loc, err);
269   };
270   return getElementType(type, indices, errorFn);
271 }
272 
273 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
274                            SMLoc loc) {
275   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
276     return parser.emitError(loc, err);
277   };
278   return getElementType(type, indices, errorFn);
279 }
280 
281 template <typename ExtendedBinaryOp>
282 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
283   auto resultType = llvm::cast<spirv::StructType>(op.getType());
284   if (resultType.getNumElements() != 2)
285     return op.emitOpError("expected result struct type containing two members");
286 
287   if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
288                         resultType.getElementType(0),
289                         resultType.getElementType(1)}))
290     return op.emitOpError(
291         "expected all operand types and struct member types are the same");
292 
293   return success();
294 }
295 
296 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
297                                                    OperationState &result) {
298   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
299   if (parser.parseOptionalAttrDict(result.attributes) ||
300       parser.parseOperandList(operands) || parser.parseColon())
301     return failure();
302 
303   Type resultType;
304   SMLoc loc = parser.getCurrentLocation();
305   if (parser.parseType(resultType))
306     return failure();
307 
308   auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
309   if (!structType || structType.getNumElements() != 2)
310     return parser.emitError(loc, "expected spirv.struct type with two members");
311 
312   SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
313   if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
314     return failure();
315 
316   result.addTypes(resultType);
317   return success();
318 }
319 
320 static void printArithmeticExtendedBinaryOp(Operation *op,
321                                             OpAsmPrinter &printer) {
322   printer << ' ';
323   printer.printOptionalAttrDict(op->getAttrs());
324   printer.printOperands(op->getOperands());
325   printer << " : " << op->getResultTypes().front();
326 }
327 
328 static LogicalResult verifyShiftOp(Operation *op) {
329   if (op->getOperand(0).getType() != op->getResult(0).getType()) {
330     return op->emitError("expected the same type for the first operand and "
331                          "result, but provided ")
332            << op->getOperand(0).getType() << " and "
333            << op->getResult(0).getType();
334   }
335   return success();
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // spirv.mlir.addressof
340 //===----------------------------------------------------------------------===//
341 
342 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
343                                spirv::GlobalVariableOp var) {
344   build(builder, state, var.getType(), SymbolRefAttr::get(var));
345 }
346 
347 LogicalResult spirv::AddressOfOp::verify() {
348   auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
349       SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
350                                            getVariableAttr()));
351   if (!varOp) {
352     return emitOpError("expected spirv.GlobalVariable symbol");
353   }
354   if (getPointer().getType() != varOp.getType()) {
355     return emitOpError(
356         "result type mismatch with the referenced global variable's type");
357   }
358   return success();
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // spirv.CompositeConstruct
363 //===----------------------------------------------------------------------===//
364 
365 LogicalResult spirv::CompositeConstructOp::verify() {
366   operand_range constituents = this->getConstituents();
367 
368   // There are 4 cases with varying verification rules:
369   // 1. Cooperative Matrices (1 constituent)
370   // 2. Structs (1 constituent for each member)
371   // 3. Arrays (1 constituent for each array element)
372   // 4. Vectors (1 constituent (sub-)element for each vector element)
373 
374   auto coopElementType =
375       llvm::TypeSwitch<Type, Type>(getType())
376           .Case<spirv::CooperativeMatrixType>(
377               [](auto coopType) { return coopType.getElementType(); })
378           .Default([](Type) { return nullptr; });
379 
380   // Case 1. -- matrices.
381   if (coopElementType) {
382     if (constituents.size() != 1)
383       return emitOpError("has incorrect number of operands: expected ")
384              << "1, but provided " << constituents.size();
385     if (coopElementType != constituents.front().getType())
386       return emitOpError("operand type mismatch: expected operand type ")
387              << coopElementType << ", but provided "
388              << constituents.front().getType();
389     return success();
390   }
391 
392   // Case 2./3./4. -- number of constituents matches the number of elements.
393   auto cType = llvm::cast<spirv::CompositeType>(getType());
394   if (constituents.size() == cType.getNumElements()) {
395     for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
396       if (constituents[index].getType() != cType.getElementType(index)) {
397         return emitOpError("operand type mismatch: expected operand type ")
398                << cType.getElementType(index) << ", but provided "
399                << constituents[index].getType();
400       }
401     }
402     return success();
403   }
404 
405   // Case 4. -- check that all constituents add up tp the expected vector type.
406   auto resultType = llvm::dyn_cast<VectorType>(cType);
407   if (!resultType)
408     return emitOpError(
409         "expected to return a vector or cooperative matrix when the number of "
410         "constituents is less than what the result needs");
411 
412   SmallVector<unsigned> sizes;
413   for (Value component : constituents) {
414     if (!llvm::isa<VectorType>(component.getType()) &&
415         !component.getType().isIntOrFloat())
416       return emitOpError("operand type mismatch: expected operand to have "
417                          "a scalar or vector type, but provided ")
418              << component.getType();
419 
420     Type elementType = component.getType();
421     if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
422       sizes.push_back(vectorType.getNumElements());
423       elementType = vectorType.getElementType();
424     } else {
425       sizes.push_back(1);
426     }
427 
428     if (elementType != resultType.getElementType())
429       return emitOpError("operand element type mismatch: expected to be ")
430              << resultType.getElementType() << ", but provided " << elementType;
431   }
432   unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
433   if (totalCount != cType.getNumElements())
434     return emitOpError("has incorrect number of operands: expected ")
435            << cType.getNumElements() << ", but provided " << totalCount;
436   return success();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // spirv.CompositeExtractOp
441 //===----------------------------------------------------------------------===//
442 
443 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
444                                       Value composite,
445                                       ArrayRef<int32_t> indices) {
446   auto indexAttr = builder.getI32ArrayAttr(indices);
447   auto elementType =
448       getElementType(composite.getType(), indexAttr, state.location);
449   if (!elementType) {
450     return;
451   }
452   build(builder, state, elementType, composite, indexAttr);
453 }
454 
455 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
456                                              OperationState &result) {
457   OpAsmParser::UnresolvedOperand compositeInfo;
458   Attribute indicesAttr;
459   StringRef indicesAttrName =
460       spirv::CompositeExtractOp::getIndicesAttrName(result.name);
461   Type compositeType;
462   SMLoc attrLocation;
463 
464   if (parser.parseOperand(compositeInfo) ||
465       parser.getCurrentLocation(&attrLocation) ||
466       parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
467       parser.parseColonType(compositeType) ||
468       parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
469     return failure();
470   }
471 
472   Type resultType =
473       getElementType(compositeType, indicesAttr, parser, attrLocation);
474   if (!resultType) {
475     return failure();
476   }
477   result.addTypes(resultType);
478   return success();
479 }
480 
481 void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
482   printer << ' ' << getComposite() << getIndices() << " : "
483           << getComposite().getType();
484 }
485 
486 LogicalResult spirv::CompositeExtractOp::verify() {
487   auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
488   auto resultType =
489       getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
490   if (!resultType)
491     return failure();
492 
493   if (resultType != getType()) {
494     return emitOpError("invalid result type: expected ")
495            << resultType << " but provided " << getType();
496   }
497 
498   return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // spirv.CompositeInsert
503 //===----------------------------------------------------------------------===//
504 
505 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
506                                      Value object, Value composite,
507                                      ArrayRef<int32_t> indices) {
508   auto indexAttr = builder.getI32ArrayAttr(indices);
509   build(builder, state, composite.getType(), object, composite, indexAttr);
510 }
511 
512 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
513                                             OperationState &result) {
514   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
515   Type objectType, compositeType;
516   Attribute indicesAttr;
517   StringRef indicesAttrName =
518       spirv::CompositeInsertOp::getIndicesAttrName(result.name);
519   auto loc = parser.getCurrentLocation();
520 
521   return failure(
522       parser.parseOperandList(operands, 2) ||
523       parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
524       parser.parseColonType(objectType) ||
525       parser.parseKeywordType("into", compositeType) ||
526       parser.resolveOperands(operands, {objectType, compositeType}, loc,
527                              result.operands) ||
528       parser.addTypesToList(compositeType, result.types));
529 }
530 
531 LogicalResult spirv::CompositeInsertOp::verify() {
532   auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
533   auto objectType =
534       getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
535   if (!objectType)
536     return failure();
537 
538   if (objectType != getObject().getType()) {
539     return emitOpError("object operand type should be ")
540            << objectType << ", but found " << getObject().getType();
541   }
542 
543   if (getComposite().getType() != getType()) {
544     return emitOpError("result type should be the same as "
545                        "the composite type, but found ")
546            << getComposite().getType() << " vs " << getType();
547   }
548 
549   return success();
550 }
551 
552 void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
553   printer << " " << getObject() << ", " << getComposite() << getIndices()
554           << " : " << getObject().getType() << " into "
555           << getComposite().getType();
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // spirv.Constant
560 //===----------------------------------------------------------------------===//
561 
562 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
563                                      OperationState &result) {
564   Attribute value;
565   StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
566   if (parser.parseAttribute(value, valueAttrName, result.attributes))
567     return failure();
568 
569   Type type = NoneType::get(parser.getContext());
570   if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
571     type = typedAttr.getType();
572   if (llvm::isa<NoneType, TensorType>(type)) {
573     if (parser.parseColonType(type))
574       return failure();
575   }
576 
577   return parser.addTypeToList(type, result.types);
578 }
579 
580 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
581   printer << ' ' << getValue();
582   if (llvm::isa<spirv::ArrayType>(getType()))
583     printer << " : " << getType();
584 }
585 
586 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
587                                         Type opType) {
588   if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
589     auto valueType = llvm::cast<TypedAttr>(value).getType();
590     if (valueType != opType)
591       return op.emitOpError("result type (")
592              << opType << ") does not match value type (" << valueType << ")";
593     return success();
594   }
595   if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
596     auto valueType = llvm::cast<TypedAttr>(value).getType();
597     if (valueType == opType)
598       return success();
599     auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
600     auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
601     if (!arrayType)
602       return op.emitOpError("result or element type (")
603              << opType << ") does not match value type (" << valueType
604              << "), must be the same or spirv.array";
605 
606     int numElements = arrayType.getNumElements();
607     auto opElemType = arrayType.getElementType();
608     while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
609       numElements *= t.getNumElements();
610       opElemType = t.getElementType();
611     }
612     if (!opElemType.isIntOrFloat())
613       return op.emitOpError("only support nested array result type");
614 
615     auto valueElemType = shapedType.getElementType();
616     if (valueElemType != opElemType) {
617       return op.emitOpError("result element type (")
618              << opElemType << ") does not match value element type ("
619              << valueElemType << ")";
620     }
621 
622     if (numElements != shapedType.getNumElements()) {
623       return op.emitOpError("result number of elements (")
624              << numElements << ") does not match value number of elements ("
625              << shapedType.getNumElements() << ")";
626     }
627     return success();
628   }
629   if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
630     auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
631     if (!arrayType)
632       return op.emitOpError(
633           "must have spirv.array result type for array value");
634     Type elemType = arrayType.getElementType();
635     for (Attribute element : arrayAttr.getValue()) {
636       // Verify array elements recursively.
637       if (failed(verifyConstantType(op, element, elemType)))
638         return failure();
639     }
640     return success();
641   }
642   return op.emitOpError("cannot have attribute: ") << value;
643 }
644 
645 LogicalResult spirv::ConstantOp::verify() {
646   // ODS already generates checks to make sure the result type is valid. We just
647   // need to additionally check that the value's attribute type is consistent
648   // with the result type.
649   return verifyConstantType(*this, getValueAttr(), getType());
650 }
651 
652 bool spirv::ConstantOp::isBuildableWith(Type type) {
653   // Must be valid SPIR-V type first.
654   if (!llvm::isa<spirv::SPIRVType>(type))
655     return false;
656 
657   if (isa<SPIRVDialect>(type.getDialect())) {
658     // TODO: support constant struct
659     return llvm::isa<spirv::ArrayType>(type);
660   }
661 
662   return true;
663 }
664 
665 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
666                                              OpBuilder &builder) {
667   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
668     unsigned width = intType.getWidth();
669     if (width == 1)
670       return builder.create<spirv::ConstantOp>(loc, type,
671                                                builder.getBoolAttr(false));
672     return builder.create<spirv::ConstantOp>(
673         loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
674   }
675   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
676     return builder.create<spirv::ConstantOp>(
677         loc, type, builder.getFloatAttr(floatType, 0.0));
678   }
679   if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
680     Type elemType = vectorType.getElementType();
681     if (llvm::isa<IntegerType>(elemType)) {
682       return builder.create<spirv::ConstantOp>(
683           loc, type,
684           DenseElementsAttr::get(vectorType,
685                                  IntegerAttr::get(elemType, 0).getValue()));
686     }
687     if (llvm::isa<FloatType>(elemType)) {
688       return builder.create<spirv::ConstantOp>(
689           loc, type,
690           DenseFPElementsAttr::get(vectorType,
691                                    FloatAttr::get(elemType, 0.0).getValue()));
692     }
693   }
694 
695   llvm_unreachable("unimplemented types for ConstantOp::getZero()");
696 }
697 
698 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
699                                             OpBuilder &builder) {
700   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
701     unsigned width = intType.getWidth();
702     if (width == 1)
703       return builder.create<spirv::ConstantOp>(loc, type,
704                                                builder.getBoolAttr(true));
705     return builder.create<spirv::ConstantOp>(
706         loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
707   }
708   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
709     return builder.create<spirv::ConstantOp>(
710         loc, type, builder.getFloatAttr(floatType, 1.0));
711   }
712   if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
713     Type elemType = vectorType.getElementType();
714     if (llvm::isa<IntegerType>(elemType)) {
715       return builder.create<spirv::ConstantOp>(
716           loc, type,
717           DenseElementsAttr::get(vectorType,
718                                  IntegerAttr::get(elemType, 1).getValue()));
719     }
720     if (llvm::isa<FloatType>(elemType)) {
721       return builder.create<spirv::ConstantOp>(
722           loc, type,
723           DenseFPElementsAttr::get(vectorType,
724                                    FloatAttr::get(elemType, 1.0).getValue()));
725     }
726   }
727 
728   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
729 }
730 
731 void mlir::spirv::ConstantOp::getAsmResultNames(
732     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
733   Type type = getType();
734 
735   SmallString<32> specialNameBuffer;
736   llvm::raw_svector_ostream specialName(specialNameBuffer);
737   specialName << "cst";
738 
739   IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
740 
741   if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
742     if (intTy && intTy.getWidth() == 1) {
743       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
744     }
745 
746     if (intTy.isSignless()) {
747       specialName << intCst.getInt();
748     } else if (intTy.isUnsigned()) {
749       specialName << intCst.getUInt();
750     } else {
751       specialName << intCst.getSInt();
752     }
753   }
754 
755   if (intTy || llvm::isa<FloatType>(type)) {
756     specialName << '_' << type;
757   }
758 
759   if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
760     specialName << "_vec_";
761     specialName << vecType.getDimSize(0);
762 
763     Type elementType = vecType.getElementType();
764 
765     if (llvm::isa<IntegerType>(elementType) ||
766         llvm::isa<FloatType>(elementType)) {
767       specialName << "x" << elementType;
768     }
769   }
770 
771   setNameFn(getResult(), specialName.str());
772 }
773 
774 void mlir::spirv::AddressOfOp::getAsmResultNames(
775     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
776   SmallString<32> specialNameBuffer;
777   llvm::raw_svector_ostream specialName(specialNameBuffer);
778   specialName << getVariable() << "_addr";
779   setNameFn(getResult(), specialName.str());
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // spirv.ControlBarrierOp
784 //===----------------------------------------------------------------------===//
785 
786 LogicalResult spirv::ControlBarrierOp::verify() {
787   return verifyMemorySemantics(getOperation(), getMemorySemantics());
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // spirv.EntryPoint
792 //===----------------------------------------------------------------------===//
793 
794 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
795                                 spirv::ExecutionModel executionModel,
796                                 spirv::FuncOp function,
797                                 ArrayRef<Attribute> interfaceVars) {
798   build(builder, state,
799         spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
800         SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
801 }
802 
803 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
804                                        OperationState &result) {
805   spirv::ExecutionModel execModel;
806   SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
807   SmallVector<Type, 0> idTypes;
808   SmallVector<Attribute, 4> interfaceVars;
809 
810   FlatSymbolRefAttr fn;
811   if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
812       parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
813     return failure();
814   }
815 
816   if (!parser.parseOptionalComma()) {
817     // Parse the interface variables
818     if (parser.parseCommaSeparatedList([&]() -> ParseResult {
819           // The name of the interface variable attribute isnt important
820           FlatSymbolRefAttr var;
821           NamedAttrList attrs;
822           if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
823             return failure();
824           interfaceVars.push_back(var);
825           return success();
826         }))
827       return failure();
828   }
829   result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
830                       parser.getBuilder().getArrayAttr(interfaceVars));
831   return success();
832 }
833 
834 void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
835   printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
836   printer.printSymbolName(getFn());
837   auto interfaceVars = getInterface().getValue();
838   if (!interfaceVars.empty()) {
839     printer << ", ";
840     llvm::interleaveComma(interfaceVars, printer);
841   }
842 }
843 
844 LogicalResult spirv::EntryPointOp::verify() {
845   // Checks for fn and interface symbol reference are done in spirv::ModuleOp
846   // verification.
847   return success();
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // spirv.ExecutionMode
852 //===----------------------------------------------------------------------===//
853 
854 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
855                                    spirv::FuncOp function,
856                                    spirv::ExecutionMode executionMode,
857                                    ArrayRef<int32_t> params) {
858   build(builder, state, SymbolRefAttr::get(function),
859         spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
860         builder.getI32ArrayAttr(params));
861 }
862 
863 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
864                                           OperationState &result) {
865   spirv::ExecutionMode execMode;
866   Attribute fn;
867   if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
868       parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
869     return failure();
870   }
871 
872   SmallVector<int32_t, 4> values;
873   Type i32Type = parser.getBuilder().getIntegerType(32);
874   while (!parser.parseOptionalComma()) {
875     NamedAttrList attr;
876     Attribute value;
877     if (parser.parseAttribute(value, i32Type, "value", attr)) {
878       return failure();
879     }
880     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
881   }
882   StringRef valuesAttrName =
883       spirv::ExecutionModeOp::getValuesAttrName(result.name);
884   result.addAttribute(valuesAttrName,
885                       parser.getBuilder().getI32ArrayAttr(values));
886   return success();
887 }
888 
889 void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
890   printer << " ";
891   printer.printSymbolName(getFn());
892   printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
893   auto values = this->getValues();
894   if (values.empty())
895     return;
896   printer << ", ";
897   llvm::interleaveComma(values, printer, [&](Attribute a) {
898     printer << llvm::cast<IntegerAttr>(a).getInt();
899   });
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // spirv.func
904 //===----------------------------------------------------------------------===//
905 
906 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
907   SmallVector<OpAsmParser::Argument> entryArgs;
908   SmallVector<DictionaryAttr> resultAttrs;
909   SmallVector<Type> resultTypes;
910   auto &builder = parser.getBuilder();
911 
912   // Parse the name as a symbol.
913   StringAttr nameAttr;
914   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
915                              result.attributes))
916     return failure();
917 
918   // Parse the function signature.
919   bool isVariadic = false;
920   if (function_interface_impl::parseFunctionSignature(
921           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
922           resultAttrs))
923     return failure();
924 
925   SmallVector<Type> argTypes;
926   for (auto &arg : entryArgs)
927     argTypes.push_back(arg.type);
928   auto fnType = builder.getFunctionType(argTypes, resultTypes);
929   result.addAttribute(getFunctionTypeAttrName(result.name),
930                       TypeAttr::get(fnType));
931 
932   // Parse the optional function control keyword.
933   spirv::FunctionControl fnControl;
934   if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
935     return failure();
936 
937   // If additional attributes are present, parse them.
938   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
939     return failure();
940 
941   // Add the attributes to the function arguments.
942   assert(resultAttrs.size() == resultTypes.size());
943   function_interface_impl::addArgAndResultAttrs(
944       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
945       getResAttrsAttrName(result.name));
946 
947   // Parse the optional function body.
948   auto *body = result.addRegion();
949   OptionalParseResult parseResult =
950       parser.parseOptionalRegion(*body, entryArgs);
951   return failure(parseResult.has_value() && failed(*parseResult));
952 }
953 
954 void spirv::FuncOp::print(OpAsmPrinter &printer) {
955   // Print function name, signature, and control.
956   printer << " ";
957   printer.printSymbolName(getSymName());
958   auto fnType = getFunctionType();
959   function_interface_impl::printFunctionSignature(
960       printer, *this, fnType.getInputs(),
961       /*isVariadic=*/false, fnType.getResults());
962   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
963           << "\"";
964   function_interface_impl::printFunctionAttributes(
965       printer, *this,
966       {spirv::attributeName<spirv::FunctionControl>(),
967        getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
968        getFunctionControlAttrName()});
969 
970   // Print the body if this is not an external function.
971   Region &body = this->getBody();
972   if (!body.empty()) {
973     printer << ' ';
974     printer.printRegion(body, /*printEntryBlockArgs=*/false,
975                         /*printBlockTerminators=*/true);
976   }
977 }
978 
979 LogicalResult spirv::FuncOp::verifyType() {
980   FunctionType fnType = getFunctionType();
981   if (fnType.getNumResults() > 1)
982     return emitOpError("cannot have more than one result");
983 
984   auto hasDecorationAttr = [&](spirv::Decoration decoration,
985                                unsigned argIndex) {
986     auto func = llvm::cast<FunctionOpInterface>(getOperation());
987     for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
988       if (argAttr.getName() != spirv::DecorationAttr::name)
989         continue;
990       if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
991         return decAttr.getValue() == decoration;
992     }
993     return false;
994   };
995 
996   for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
997     Type param = fnType.getInputs()[i];
998     auto inputPtrType = dyn_cast<spirv::PointerType>(param);
999     if (!inputPtrType)
1000       continue;
1001 
1002     auto pointeePtrType =
1003         dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1004     if (pointeePtrType) {
1005       // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1006       // > If an OpFunctionParameter is a pointer (or contains a pointer)
1007       // > and the type it points to is a pointer in the PhysicalStorageBuffer
1008       // > storage class, the function parameter must be decorated with exactly
1009       // > one of AliasedPointer or RestrictPointer.
1010       if (pointeePtrType.getStorageClass() !=
1011           spirv::StorageClass::PhysicalStorageBuffer)
1012         continue;
1013 
1014       bool hasAliasedPtr =
1015           hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1016       bool hasRestrictPtr =
1017           hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1018       if (!hasAliasedPtr && !hasRestrictPtr)
1019         return emitOpError()
1020                << "with a pointer points to a physical buffer pointer must "
1021                   "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1022       continue;
1023     }
1024     // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1025     // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1026     // > the PhysicalStorageBuffer storage class, the function parameter must
1027     // > be decorated with exactly one of Aliased or Restrict.
1028     if (auto pointeeArrayType =
1029             dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1030       pointeePtrType =
1031           dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1032     } else {
1033       pointeePtrType = inputPtrType;
1034     }
1035 
1036     if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1037                                spirv::StorageClass::PhysicalStorageBuffer)
1038       continue;
1039 
1040     bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1041     bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1042     if (!hasAliased && !hasRestrict)
1043       return emitOpError() << "with physical buffer pointer must be decorated "
1044                               "either 'Aliased' or 'Restrict'";
1045   }
1046 
1047   return success();
1048 }
1049 
1050 LogicalResult spirv::FuncOp::verifyBody() {
1051   FunctionType fnType = getFunctionType();
1052 
1053   auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1054     if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1055       if (fnType.getNumResults() != 0)
1056         return retOp.emitOpError("cannot be used in functions returning value");
1057     } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1058       if (fnType.getNumResults() != 1)
1059         return retOp.emitOpError(
1060                    "returns 1 value but enclosing function requires ")
1061                << fnType.getNumResults() << " results";
1062 
1063       auto retOperandType = retOp.getValue().getType();
1064       auto fnResultType = fnType.getResult(0);
1065       if (retOperandType != fnResultType)
1066         return retOp.emitOpError(" return value's type (")
1067                << retOperandType << ") mismatch with function's result type ("
1068                << fnResultType << ")";
1069     }
1070     return WalkResult::advance();
1071   });
1072 
1073   // TODO: verify other bits like linkage type.
1074 
1075   return failure(walkResult.wasInterrupted());
1076 }
1077 
1078 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1079                           StringRef name, FunctionType type,
1080                           spirv::FunctionControl control,
1081                           ArrayRef<NamedAttribute> attrs) {
1082   state.addAttribute(SymbolTable::getSymbolAttrName(),
1083                      builder.getStringAttr(name));
1084   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1085   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1086                      builder.getAttr<spirv::FunctionControlAttr>(control));
1087   state.attributes.append(attrs.begin(), attrs.end());
1088   state.addRegion();
1089 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // spirv.GLFClampOp
1093 //===----------------------------------------------------------------------===//
1094 
1095 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1096                                      OperationState &result) {
1097   return parseOneResultSameOperandTypeOp(parser, result);
1098 }
1099 void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // spirv.GLUClampOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1106                                      OperationState &result) {
1107   return parseOneResultSameOperandTypeOp(parser, result);
1108 }
1109 void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // spirv.GLSClampOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1116                                      OperationState &result) {
1117   return parseOneResultSameOperandTypeOp(parser, result);
1118 }
1119 void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // spirv.GLFmaOp
1123 //===----------------------------------------------------------------------===//
1124 
1125 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1126   return parseOneResultSameOperandTypeOp(parser, result);
1127 }
1128 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // spirv.GlobalVariable
1132 //===----------------------------------------------------------------------===//
1133 
1134 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1135                                     Type type, StringRef name,
1136                                     unsigned descriptorSet, unsigned binding) {
1137   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1138   state.addAttribute(
1139       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1140       builder.getI32IntegerAttr(descriptorSet));
1141   state.addAttribute(
1142       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1143       builder.getI32IntegerAttr(binding));
1144 }
1145 
1146 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1147                                     Type type, StringRef name,
1148                                     spirv::BuiltIn builtin) {
1149   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1150   state.addAttribute(
1151       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1152       builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1153 }
1154 
1155 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1156                                            OperationState &result) {
1157   // Parse variable name.
1158   StringAttr nameAttr;
1159   StringRef initializerAttrName =
1160       spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1161   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1162                              result.attributes)) {
1163     return failure();
1164   }
1165 
1166   // Parse optional initializer
1167   if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1168     FlatSymbolRefAttr initSymbol;
1169     if (parser.parseLParen() ||
1170         parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1171                               result.attributes) ||
1172         parser.parseRParen())
1173       return failure();
1174   }
1175 
1176   if (parseVariableDecorations(parser, result)) {
1177     return failure();
1178   }
1179 
1180   Type type;
1181   StringRef typeAttrName =
1182       spirv::GlobalVariableOp::getTypeAttrName(result.name);
1183   auto loc = parser.getCurrentLocation();
1184   if (parser.parseColonType(type)) {
1185     return failure();
1186   }
1187   if (!llvm::isa<spirv::PointerType>(type)) {
1188     return parser.emitError(loc, "expected spirv.ptr type");
1189   }
1190   result.addAttribute(typeAttrName, TypeAttr::get(type));
1191 
1192   return success();
1193 }
1194 
1195 void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1196   SmallVector<StringRef, 4> elidedAttrs{
1197       spirv::attributeName<spirv::StorageClass>()};
1198 
1199   // Print variable name.
1200   printer << ' ';
1201   printer.printSymbolName(getSymName());
1202   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1203 
1204   StringRef initializerAttrName = this->getInitializerAttrName();
1205   // Print optional initializer
1206   if (auto initializer = this->getInitializer()) {
1207     printer << " " << initializerAttrName << '(';
1208     printer.printSymbolName(*initializer);
1209     printer << ')';
1210     elidedAttrs.push_back(initializerAttrName);
1211   }
1212 
1213   StringRef typeAttrName = this->getTypeAttrName();
1214   elidedAttrs.push_back(typeAttrName);
1215   spirv::printVariableDecorations(*this, printer, elidedAttrs);
1216   printer << " : " << getType();
1217 }
1218 
1219 LogicalResult spirv::GlobalVariableOp::verify() {
1220   if (!llvm::isa<spirv::PointerType>(getType()))
1221     return emitOpError("result must be of a !spv.ptr type");
1222 
1223   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1224   // object. It cannot be Generic. It must be the same as the Storage Class
1225   // operand of the Result Type."
1226   // Also, Function storage class is reserved by spirv.Variable.
1227   auto storageClass = this->storageClass();
1228   if (storageClass == spirv::StorageClass::Generic ||
1229       storageClass == spirv::StorageClass::Function) {
1230     return emitOpError("storage class cannot be '")
1231            << stringifyStorageClass(storageClass) << "'";
1232   }
1233 
1234   if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1235           this->getInitializerAttrName())) {
1236     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
1237         (*this)->getParentOp(), init.getAttr());
1238     // TODO: Currently only variable initialization with specialization
1239     // constants and other variables is supported. They could be normal
1240     // constants in the module scope as well.
1241     if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1242                         spirv::SpecConstantCompositeOp>(initOp)) {
1243       return emitOpError("initializer must be result of a "
1244                          "spirv.SpecConstant or spirv.GlobalVariable or "
1245                          "spirv.SpecConstantCompositeOp op");
1246     }
1247   }
1248 
1249   return success();
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // spirv.INTEL.SubgroupBlockRead
1254 //===----------------------------------------------------------------------===//
1255 
1256 LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1257   if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1258     return failure();
1259 
1260   return success();
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // spirv.INTEL.SubgroupBlockWrite
1265 //===----------------------------------------------------------------------===//
1266 
1267 ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1268                                                     OperationState &result) {
1269   // Parse the storage class specification
1270   spirv::StorageClass storageClass;
1271   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1272   auto loc = parser.getCurrentLocation();
1273   Type elementType;
1274   if (parseEnumStrAttr(storageClass, parser) ||
1275       parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1276       parser.parseType(elementType)) {
1277     return failure();
1278   }
1279 
1280   auto ptrType = spirv::PointerType::get(elementType, storageClass);
1281   if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1282     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1283 
1284   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1285                              result.operands)) {
1286     return failure();
1287   }
1288   return success();
1289 }
1290 
1291 void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1292   printer << " " << getPtr() << ", " << getValue() << " : "
1293           << getValue().getType();
1294 }
1295 
1296 LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1297   if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1298     return failure();
1299 
1300   return success();
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // spirv.IAddCarryOp
1305 //===----------------------------------------------------------------------===//
1306 
1307 LogicalResult spirv::IAddCarryOp::verify() {
1308   return ::verifyArithmeticExtendedBinaryOp(*this);
1309 }
1310 
1311 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1312                                       OperationState &result) {
1313   return ::parseArithmeticExtendedBinaryOp(parser, result);
1314 }
1315 
1316 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1317   ::printArithmeticExtendedBinaryOp(*this, printer);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // spirv.ISubBorrowOp
1322 //===----------------------------------------------------------------------===//
1323 
1324 LogicalResult spirv::ISubBorrowOp::verify() {
1325   return ::verifyArithmeticExtendedBinaryOp(*this);
1326 }
1327 
1328 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1329                                        OperationState &result) {
1330   return ::parseArithmeticExtendedBinaryOp(parser, result);
1331 }
1332 
1333 void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1334   ::printArithmeticExtendedBinaryOp(*this, printer);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // spirv.SMulExtended
1339 //===----------------------------------------------------------------------===//
1340 
1341 LogicalResult spirv::SMulExtendedOp::verify() {
1342   return ::verifyArithmeticExtendedBinaryOp(*this);
1343 }
1344 
1345 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1346                                          OperationState &result) {
1347   return ::parseArithmeticExtendedBinaryOp(parser, result);
1348 }
1349 
1350 void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1351   ::printArithmeticExtendedBinaryOp(*this, printer);
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // spirv.UMulExtended
1356 //===----------------------------------------------------------------------===//
1357 
1358 LogicalResult spirv::UMulExtendedOp::verify() {
1359   return ::verifyArithmeticExtendedBinaryOp(*this);
1360 }
1361 
1362 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1363                                          OperationState &result) {
1364   return ::parseArithmeticExtendedBinaryOp(parser, result);
1365 }
1366 
1367 void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1368   ::printArithmeticExtendedBinaryOp(*this, printer);
1369 }
1370 
1371 //===----------------------------------------------------------------------===//
1372 // spirv.MemoryBarrierOp
1373 //===----------------------------------------------------------------------===//
1374 
1375 LogicalResult spirv::MemoryBarrierOp::verify() {
1376   return verifyMemorySemantics(getOperation(), getMemorySemantics());
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // spirv.module
1381 //===----------------------------------------------------------------------===//
1382 
1383 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1384                             std::optional<StringRef> name) {
1385   OpBuilder::InsertionGuard guard(builder);
1386   builder.createBlock(state.addRegion());
1387   if (name) {
1388     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1389                             builder.getStringAttr(*name));
1390   }
1391 }
1392 
1393 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1394                             spirv::AddressingModel addressingModel,
1395                             spirv::MemoryModel memoryModel,
1396                             std::optional<VerCapExtAttr> vceTriple,
1397                             std::optional<StringRef> name) {
1398   state.addAttribute(
1399       "addressing_model",
1400       builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1401   state.addAttribute("memory_model",
1402                      builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1403   OpBuilder::InsertionGuard guard(builder);
1404   builder.createBlock(state.addRegion());
1405   if (vceTriple)
1406     state.addAttribute(getVCETripleAttrName(), *vceTriple);
1407   if (name)
1408     state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1409                        builder.getStringAttr(*name));
1410 }
1411 
1412 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1413                                    OperationState &result) {
1414   Region *body = result.addRegion();
1415 
1416   // If the name is present, parse it.
1417   StringAttr nameAttr;
1418   (void)parser.parseOptionalSymbolName(
1419       nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1420 
1421   // Parse attributes
1422   spirv::AddressingModel addrModel;
1423   spirv::MemoryModel memoryModel;
1424   if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1425                                                               result) ||
1426       spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1427                                                           result))
1428     return failure();
1429 
1430   if (succeeded(parser.parseOptionalKeyword("requires"))) {
1431     spirv::VerCapExtAttr vceTriple;
1432     if (parser.parseAttribute(vceTriple,
1433                               spirv::ModuleOp::getVCETripleAttrName(),
1434                               result.attributes))
1435       return failure();
1436   }
1437 
1438   if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1439       parser.parseRegion(*body, /*arguments=*/{}))
1440     return failure();
1441 
1442   // Make sure we have at least one block.
1443   if (body->empty())
1444     body->push_back(new Block());
1445 
1446   return success();
1447 }
1448 
1449 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1450   if (std::optional<StringRef> name = getName()) {
1451     printer << ' ';
1452     printer.printSymbolName(*name);
1453   }
1454 
1455   SmallVector<StringRef, 2> elidedAttrs;
1456 
1457   printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1458           << spirv::stringifyMemoryModel(getMemoryModel());
1459   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1460   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1461   elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1462                       mlir::SymbolTable::getSymbolAttrName()});
1463 
1464   if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1465     printer << " requires " << *triple;
1466     elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1467   }
1468 
1469   printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1470   printer << ' ';
1471   printer.printRegion(getRegion());
1472 }
1473 
1474 LogicalResult spirv::ModuleOp::verifyRegions() {
1475   Dialect *dialect = (*this)->getDialect();
1476   DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
1477       entryPoints;
1478   mlir::SymbolTable table(*this);
1479 
1480   for (auto &op : *getBody()) {
1481     if (op.getDialect() != dialect)
1482       return op.emitError("'spirv.module' can only contain spirv.* ops");
1483 
1484     // For EntryPoint op, check that the function and execution model is not
1485     // duplicated in EntryPointOps. Also verify that the interface specified
1486     // comes from globalVariables here to make this check cheaper.
1487     if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1488       auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1489       if (!funcOp) {
1490         return entryPointOp.emitError("function '")
1491                << entryPointOp.getFn() << "' not found in 'spirv.module'";
1492       }
1493       if (auto interface = entryPointOp.getInterface()) {
1494         for (Attribute varRef : interface) {
1495           auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1496           if (!varSymRef) {
1497             return entryPointOp.emitError(
1498                        "expected symbol reference for interface "
1499                        "specification instead of '")
1500                    << varRef;
1501           }
1502           auto variableOp =
1503               table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1504           if (!variableOp) {
1505             return entryPointOp.emitError("expected spirv.GlobalVariable "
1506                                           "symbol reference instead of'")
1507                    << varSymRef << "'";
1508           }
1509         }
1510       }
1511 
1512       auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1513           funcOp, entryPointOp.getExecutionModel());
1514       if (!entryPoints.try_emplace(key, entryPointOp).second)
1515         return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1516     } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1517       // If the function is external and does not have 'Import'
1518       // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1519       // LinkageAttributes is used to import external functions.
1520       auto linkageAttr = funcOp.getLinkageAttributes();
1521       auto hasImportLinkage =
1522           linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1523                           spirv::LinkageType::Import);
1524       if (funcOp.isExternal() && !hasImportLinkage)
1525         return op.emitError(
1526             "'spirv.module' cannot contain external functions "
1527             "without 'Import' linkage_attributes (LinkageAttributes)");
1528 
1529       // TODO: move this check to spirv.func.
1530       for (auto &block : funcOp)
1531         for (auto &op : block) {
1532           if (op.getDialect() != dialect)
1533             return op.emitError(
1534                 "functions in 'spirv.module' can only contain spirv.* ops");
1535         }
1536     }
1537   }
1538 
1539   return success();
1540 }
1541 
1542 //===----------------------------------------------------------------------===//
1543 // spirv.mlir.referenceof
1544 //===----------------------------------------------------------------------===//
1545 
1546 LogicalResult spirv::ReferenceOfOp::verify() {
1547   auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1548       (*this)->getParentOp(), getSpecConstAttr());
1549   Type constType;
1550 
1551   auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1552   if (specConstOp)
1553     constType = specConstOp.getDefaultValue().getType();
1554 
1555   auto specConstCompositeOp =
1556       dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1557   if (specConstCompositeOp)
1558     constType = specConstCompositeOp.getType();
1559 
1560   if (!specConstOp && !specConstCompositeOp)
1561     return emitOpError(
1562         "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1563 
1564   if (getReference().getType() != constType)
1565     return emitOpError("result type mismatch with the referenced "
1566                        "specialization constant's type");
1567 
1568   return success();
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // spirv.SpecConstant
1573 //===----------------------------------------------------------------------===//
1574 
1575 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1576                                          OperationState &result) {
1577   StringAttr nameAttr;
1578   Attribute valueAttr;
1579   StringRef defaultValueAttrName =
1580       spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1581 
1582   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1583                              result.attributes))
1584     return failure();
1585 
1586   // Parse optional spec_id.
1587   if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1588     IntegerAttr specIdAttr;
1589     if (parser.parseLParen() ||
1590         parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1591         parser.parseRParen())
1592       return failure();
1593   }
1594 
1595   if (parser.parseEqual() ||
1596       parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1597     return failure();
1598 
1599   return success();
1600 }
1601 
1602 void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1603   printer << ' ';
1604   printer.printSymbolName(getSymName());
1605   if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1606     printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1607   printer << " = " << getDefaultValue();
1608 }
1609 
1610 LogicalResult spirv::SpecConstantOp::verify() {
1611   if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1612     if (specID.getValue().isNegative())
1613       return emitOpError("SpecId cannot be negative");
1614 
1615   auto value = getDefaultValue();
1616   if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1617     // Make sure bitwidth is allowed.
1618     if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1619       return emitOpError("default value bitwidth disallowed");
1620     return success();
1621   }
1622   return emitOpError(
1623       "default value can only be a bool, integer, or float scalar");
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // spirv.VectorShuffle
1628 //===----------------------------------------------------------------------===//
1629 
1630 LogicalResult spirv::VectorShuffleOp::verify() {
1631   VectorType resultType = llvm::cast<VectorType>(getType());
1632 
1633   size_t numResultElements = resultType.getNumElements();
1634   if (numResultElements != getComponents().size())
1635     return emitOpError("result type element count (")
1636            << numResultElements
1637            << ") mismatch with the number of component selectors ("
1638            << getComponents().size() << ")";
1639 
1640   size_t totalSrcElements =
1641       llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1642       llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1643 
1644   for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1645     uint32_t index = selector.getZExtValue();
1646     if (index >= totalSrcElements &&
1647         index != std::numeric_limits<uint32_t>().max())
1648       return emitOpError("component selector ")
1649              << index << " out of range: expected to be in [0, "
1650              << totalSrcElements << ") or 0xffffffff";
1651   }
1652   return success();
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // spirv.MatrixTimesScalar
1657 //===----------------------------------------------------------------------===//
1658 
1659 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1660   Type elementType =
1661       llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1662           .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
1663               [](auto matrixType) { return matrixType.getElementType(); })
1664           .Default([](Type) { return nullptr; });
1665 
1666   assert(elementType && "Unhandled type");
1667 
1668   // Check that the scalar type is the same as the matrix element type.
1669   if (getScalar().getType() != elementType)
1670     return emitOpError("input matrix components' type and scaling value must "
1671                        "have the same type");
1672 
1673   return success();
1674 }
1675 
1676 //===----------------------------------------------------------------------===//
1677 // spirv.Transpose
1678 //===----------------------------------------------------------------------===//
1679 
1680 LogicalResult spirv::TransposeOp::verify() {
1681   auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1682   auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1683 
1684   // Verify that the input and output matrices have correct shapes.
1685   if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1686     return emitError("input matrix rows count must be equal to "
1687                      "output matrix columns count");
1688 
1689   if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1690     return emitError("input matrix columns count must be equal to "
1691                      "output matrix rows count");
1692 
1693   // Verify that the input and output matrices have the same component type
1694   if (inputMatrix.getElementType() != resultMatrix.getElementType())
1695     return emitError("input and output matrices must have the same "
1696                      "component type");
1697 
1698   return success();
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // spirv.MatrixTimesVector
1703 //===----------------------------------------------------------------------===//
1704 
1705 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1706   auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1707   auto vectorType = llvm::cast<VectorType>(getVector().getType());
1708   auto resultType = llvm::cast<VectorType>(getType());
1709 
1710   if (matrixType.getNumColumns() != vectorType.getNumElements())
1711     return emitOpError("matrix columns (")
1712            << matrixType.getNumColumns() << ") must match vector operand size ("
1713            << vectorType.getNumElements() << ")";
1714 
1715   if (resultType.getNumElements() != matrixType.getNumRows())
1716     return emitOpError("result size (")
1717            << resultType.getNumElements() << ") must match the matrix rows ("
1718            << matrixType.getNumRows() << ")";
1719 
1720   if (matrixType.getElementType() != resultType.getElementType())
1721     return emitOpError("matrix and result element types must match");
1722 
1723   return success();
1724 }
1725 
1726 //===----------------------------------------------------------------------===//
1727 // spirv.VectorTimesMatrix
1728 //===----------------------------------------------------------------------===//
1729 
1730 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1731   auto vectorType = llvm::cast<VectorType>(getVector().getType());
1732   auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1733   auto resultType = llvm::cast<VectorType>(getType());
1734 
1735   if (matrixType.getNumRows() != vectorType.getNumElements())
1736     return emitOpError("number of components in vector must equal the number "
1737                        "of components in each column in matrix");
1738 
1739   if (resultType.getNumElements() != matrixType.getNumColumns())
1740     return emitOpError("number of columns in matrix must equal the number of "
1741                        "components in result");
1742 
1743   if (matrixType.getElementType() != resultType.getElementType())
1744     return emitOpError("matrix must be a matrix with the same component type "
1745                        "as the component type in result");
1746 
1747   return success();
1748 }
1749 
1750 //===----------------------------------------------------------------------===//
1751 // spirv.MatrixTimesMatrix
1752 //===----------------------------------------------------------------------===//
1753 
1754 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1755   auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1756   auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1757   auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1758 
1759   // left matrix columns' count and right matrix rows' count must be equal
1760   if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1761     return emitError("left matrix columns' count must be equal to "
1762                      "the right matrix rows' count");
1763 
1764   // right and result matrices columns' count must be the same
1765   if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1766     return emitError(
1767         "right and result matrices must have equal columns' count");
1768 
1769   // right and result matrices component type must be the same
1770   if (rightMatrix.getElementType() != resultMatrix.getElementType())
1771     return emitError("right and result matrices' component type must"
1772                      " be the same");
1773 
1774   // left and result matrices component type must be the same
1775   if (leftMatrix.getElementType() != resultMatrix.getElementType())
1776     return emitError("left and result matrices' component type"
1777                      " must be the same");
1778 
1779   // left and result matrices rows count must be the same
1780   if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1781     return emitError("left and result matrices must have equal rows' count");
1782 
1783   return success();
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 // spirv.SpecConstantComposite
1788 //===----------------------------------------------------------------------===//
1789 
1790 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1791                                                   OperationState &result) {
1792 
1793   StringAttr compositeName;
1794   if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1795                              result.attributes))
1796     return failure();
1797 
1798   if (parser.parseLParen())
1799     return failure();
1800 
1801   SmallVector<Attribute, 4> constituents;
1802 
1803   do {
1804     // The name of the constituent attribute isn't important
1805     const char *attrName = "spec_const";
1806     FlatSymbolRefAttr specConstRef;
1807     NamedAttrList attrs;
1808 
1809     if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1810       return failure();
1811 
1812     constituents.push_back(specConstRef);
1813   } while (!parser.parseOptionalComma());
1814 
1815   if (parser.parseRParen())
1816     return failure();
1817 
1818   StringAttr compositeSpecConstituentsName =
1819       spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1820   result.addAttribute(compositeSpecConstituentsName,
1821                       parser.getBuilder().getArrayAttr(constituents));
1822 
1823   Type type;
1824   if (parser.parseColonType(type))
1825     return failure();
1826 
1827   StringAttr typeAttrName =
1828       spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1829   result.addAttribute(typeAttrName, TypeAttr::get(type));
1830 
1831   return success();
1832 }
1833 
1834 void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1835   printer << " ";
1836   printer.printSymbolName(getSymName());
1837   printer << " (";
1838   auto constituents = this->getConstituents().getValue();
1839 
1840   if (!constituents.empty())
1841     llvm::interleaveComma(constituents, printer);
1842 
1843   printer << ") : " << getType();
1844 }
1845 
1846 LogicalResult spirv::SpecConstantCompositeOp::verify() {
1847   auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1848   auto constituents = this->getConstituents().getValue();
1849 
1850   if (!cType)
1851     return emitError("result type must be a composite type, but provided ")
1852            << getType();
1853 
1854   if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1855     return emitError("unsupported composite type  ") << cType;
1856   if (constituents.size() != cType.getNumElements())
1857     return emitError("has incorrect number of operands: expected ")
1858            << cType.getNumElements() << ", but provided "
1859            << constituents.size();
1860 
1861   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1862     auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1863 
1864     auto constituentSpecConstOp =
1865         dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1866             (*this)->getParentOp(), constituent.getAttr()));
1867 
1868     if (constituentSpecConstOp.getDefaultValue().getType() !=
1869         cType.getElementType(index))
1870       return emitError("has incorrect types of operands: expected ")
1871              << cType.getElementType(index) << ", but provided "
1872              << constituentSpecConstOp.getDefaultValue().getType();
1873   }
1874 
1875   return success();
1876 }
1877 
1878 //===----------------------------------------------------------------------===//
1879 // spirv.SpecConstantOperation
1880 //===----------------------------------------------------------------------===//
1881 
1882 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1883                                                   OperationState &result) {
1884   Region *body = result.addRegion();
1885 
1886   if (parser.parseKeyword("wraps"))
1887     return failure();
1888 
1889   body->push_back(new Block);
1890   Block &block = body->back();
1891   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1892 
1893   if (!wrappedOp)
1894     return failure();
1895 
1896   OpBuilder builder(parser.getContext());
1897   builder.setInsertionPointToEnd(&block);
1898   builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1899   result.location = wrappedOp->getLoc();
1900 
1901   result.addTypes(wrappedOp->getResult(0).getType());
1902 
1903   if (parser.parseOptionalAttrDict(result.attributes))
1904     return failure();
1905 
1906   return success();
1907 }
1908 
1909 void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
1910   printer << " wraps ";
1911   printer.printGenericOp(&getBody().front().front());
1912 }
1913 
1914 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1915   Block &block = getRegion().getBlocks().front();
1916 
1917   if (block.getOperations().size() != 2)
1918     return emitOpError("expected exactly 2 nested ops");
1919 
1920   Operation &enclosedOp = block.getOperations().front();
1921 
1922   if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
1923     return emitOpError("invalid enclosed op");
1924 
1925   for (auto operand : enclosedOp.getOperands())
1926     if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1927              spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1928       return emitOpError(
1929           "invalid operand, must be defined by a constant operation");
1930 
1931   return success();
1932 }
1933 
1934 //===----------------------------------------------------------------------===//
1935 // spirv.GL.FrexpStruct
1936 //===----------------------------------------------------------------------===//
1937 
1938 LogicalResult spirv::GLFrexpStructOp::verify() {
1939   spirv::StructType structTy =
1940       llvm::dyn_cast<spirv::StructType>(getResult().getType());
1941 
1942   if (structTy.getNumElements() != 2)
1943     return emitError("result type must be a struct type with two memebers");
1944 
1945   Type significandTy = structTy.getElementType(0);
1946   Type exponentTy = structTy.getElementType(1);
1947   VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1948   IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1949 
1950   Type operandTy = getOperand().getType();
1951   VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1952   FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1953 
1954   if (significandTy != operandTy)
1955     return emitError("member zero of the resulting struct type must be the "
1956                      "same type as the operand");
1957 
1958   if (exponentVecTy) {
1959     IntegerType componentIntTy =
1960         llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1961     if (!componentIntTy || componentIntTy.getWidth() != 32)
1962       return emitError("member one of the resulting struct type must"
1963                        "be a scalar or vector of 32 bit integer type");
1964   } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1965     return emitError("member one of the resulting struct type "
1966                      "must be a scalar or vector of 32 bit integer type");
1967   }
1968 
1969   // Check that the two member types have the same number of components
1970   if (operandVecTy && exponentVecTy &&
1971       (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1972     return success();
1973 
1974   if (operandFTy && exponentIntTy)
1975     return success();
1976 
1977   return emitError("member one of the resulting struct type must have the same "
1978                    "number of components as the operand type");
1979 }
1980 
1981 //===----------------------------------------------------------------------===//
1982 // spirv.GL.Ldexp
1983 //===----------------------------------------------------------------------===//
1984 
1985 LogicalResult spirv::GLLdexpOp::verify() {
1986   Type significandType = getX().getType();
1987   Type exponentType = getExp().getType();
1988 
1989   if (llvm::isa<FloatType>(significandType) !=
1990       llvm::isa<IntegerType>(exponentType))
1991     return emitOpError("operands must both be scalars or vectors");
1992 
1993   auto getNumElements = [](Type type) -> unsigned {
1994     if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1995       return vectorType.getNumElements();
1996     return 1;
1997   };
1998 
1999   if (getNumElements(significandType) != getNumElements(exponentType))
2000     return emitOpError("operands must have the same number of elements");
2001 
2002   return success();
2003 }
2004 
2005 //===----------------------------------------------------------------------===//
2006 // spirv.ImageDrefGather
2007 //===----------------------------------------------------------------------===//
2008 
2009 LogicalResult spirv::ImageDrefGatherOp::verify() {
2010   VectorType resultType = llvm::cast<VectorType>(getResult().getType());
2011   auto sampledImageType =
2012       llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
2013   auto imageType =
2014       llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
2015 
2016   if (resultType.getNumElements() != 4)
2017     return emitOpError("result type must be a vector of four components");
2018 
2019   Type elementType = resultType.getElementType();
2020   Type sampledElementType = imageType.getElementType();
2021   if (!llvm::isa<NoneType>(sampledElementType) &&
2022       elementType != sampledElementType)
2023     return emitOpError(
2024         "the component type of result must be the same as sampled type of the "
2025         "underlying image type");
2026 
2027   spirv::Dim imageDim = imageType.getDim();
2028   spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
2029 
2030   if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
2031       imageDim != spirv::Dim::Rect)
2032     return emitOpError(
2033         "the Dim operand of the underlying image type must be 2D, Cube, or "
2034         "Rect");
2035 
2036   if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
2037     return emitOpError("the MS operand of the underlying image type must be 0");
2038 
2039   spirv::ImageOperandsAttr attr = getImageoperandsAttr();
2040   auto operandArguments = getOperandArguments();
2041 
2042   return verifyImageOperands(*this, attr, operandArguments);
2043 }
2044 
2045 //===----------------------------------------------------------------------===//
2046 // spirv.ShiftLeftLogicalOp
2047 //===----------------------------------------------------------------------===//
2048 
2049 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2050   return verifyShiftOp(*this);
2051 }
2052 
2053 //===----------------------------------------------------------------------===//
2054 // spirv.ShiftRightArithmeticOp
2055 //===----------------------------------------------------------------------===//
2056 
2057 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2058   return verifyShiftOp(*this);
2059 }
2060 
2061 //===----------------------------------------------------------------------===//
2062 // spirv.ShiftRightLogicalOp
2063 //===----------------------------------------------------------------------===//
2064 
2065 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2066   return verifyShiftOp(*this);
2067 }
2068 
2069 //===----------------------------------------------------------------------===//
2070 // spirv.ImageQuerySize
2071 //===----------------------------------------------------------------------===//
2072 
2073 LogicalResult spirv::ImageQuerySizeOp::verify() {
2074   spirv::ImageType imageType =
2075       llvm::cast<spirv::ImageType>(getImage().getType());
2076   Type resultType = getResult().getType();
2077 
2078   spirv::Dim dim = imageType.getDim();
2079   spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
2080   spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
2081   switch (dim) {
2082   case spirv::Dim::Dim1D:
2083   case spirv::Dim::Dim2D:
2084   case spirv::Dim::Dim3D:
2085   case spirv::Dim::Cube:
2086     if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2087         samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2088         samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2089       return emitError(
2090           "if Dim is 1D, 2D, 3D, or Cube, "
2091           "it must also have either an MS of 1 or a Sampled of 0 or 2");
2092     break;
2093   case spirv::Dim::Buffer:
2094   case spirv::Dim::Rect:
2095     break;
2096   default:
2097     return emitError("the Dim operand of the image type must "
2098                      "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2099   }
2100 
2101   unsigned componentNumber = 0;
2102   switch (dim) {
2103   case spirv::Dim::Dim1D:
2104   case spirv::Dim::Buffer:
2105     componentNumber = 1;
2106     break;
2107   case spirv::Dim::Dim2D:
2108   case spirv::Dim::Cube:
2109   case spirv::Dim::Rect:
2110     componentNumber = 2;
2111     break;
2112   case spirv::Dim::Dim3D:
2113     componentNumber = 3;
2114     break;
2115   default:
2116     break;
2117   }
2118 
2119   if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2120     componentNumber += 1;
2121 
2122   unsigned resultComponentNumber = 1;
2123   if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2124     resultComponentNumber = resultVectorType.getNumElements();
2125 
2126   if (componentNumber != resultComponentNumber)
2127     return emitError("expected the result to have ")
2128            << componentNumber << " component(s), but found "
2129            << resultComponentNumber << " component(s)";
2130 
2131   return success();
2132 }
2133 
2134 //===----------------------------------------------------------------------===//
2135 // spirv.VectorTimesScalarOp
2136 //===----------------------------------------------------------------------===//
2137 
2138 LogicalResult spirv::VectorTimesScalarOp::verify() {
2139   if (getVector().getType() != getType())
2140     return emitOpError("vector operand and result type mismatch");
2141   auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2142   if (getScalar().getType() != scalarType)
2143     return emitOpError("scalar operand and result element type match");
2144   return success();
2145 }
2146