xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (revision a5506a39e0ae8de77136334659b526e5f224850d)
1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the memory operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Diagnostics.h"
20 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23 
24 using namespace mlir::spirv::AttrNames;
25 
26 namespace mlir::spirv {
27 
28 /// Parses optional memory access (a.k.a. memory operand) attributes attached to
29 /// a memory access operand/pointer. Specifically, parses the following syntax:
30 ///     (`[` memory-access `]`)?
31 /// where:
32 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33 ///         integer-literal | `"NonTemporal"`
34 template <typename MemoryOpTy>
35 ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
36                                         OperationState &state) {
37   // Parse an optional list of attributes staring with '['
38   if (parser.parseOptionalLSquare()) {
39     // Nothing to do
40     return success();
41   }
42 
43   spirv::MemoryAccess memoryAccessAttr;
44   StringAttr memoryAccessAttrName =
45       MemoryOpTy::getMemoryAccessAttrName(state.name);
46   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47           memoryAccessAttr, parser, state, memoryAccessAttrName))
48     return failure();
49 
50   if (spirv::bitEnumContainsAll(memoryAccessAttr,
51                                 spirv::MemoryAccess::Aligned)) {
52     // Parse integer attribute for alignment.
53     Attribute alignmentAttr;
54     StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55     Type i32Type = parser.getBuilder().getIntegerType(32);
56     if (parser.parseComma() ||
57         parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58                               state.attributes)) {
59       return failure();
60     }
61   }
62   return parser.parseRSquare();
63 }
64 
65 // TODO Make sure to merge this and the previous function into one template
66 // parameterized by memory access attribute name and alignment. Doing so now
67 // results in VS2017 in producing an internal error (at the call site) that's
68 // not detailed enough to understand what is happening.
69 template <typename MemoryOpTy>
70 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
71                                                      OperationState &state) {
72   // Parse an optional list of attributes staring with '['
73   if (parser.parseOptionalLSquare()) {
74     // Nothing to do
75     return success();
76   }
77 
78   spirv::MemoryAccess memoryAccessAttr;
79   StringRef memoryAccessAttrName =
80       MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82           memoryAccessAttr, parser, state, memoryAccessAttrName))
83     return failure();
84 
85   if (spirv::bitEnumContainsAll(memoryAccessAttr,
86                                 spirv::MemoryAccess::Aligned)) {
87     // Parse integer attribute for alignment.
88     Attribute alignmentAttr;
89     StringAttr alignmentAttrName =
90         MemoryOpTy::getSourceAlignmentAttrName(state.name);
91     Type i32Type = parser.getBuilder().getIntegerType(32);
92     if (parser.parseComma() ||
93         parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94                               state.attributes)) {
95       return failure();
96     }
97   }
98   return parser.parseRSquare();
99 }
100 
101 // TODO Make sure to merge this and the previous function into one template
102 // parameterized by memory access attribute name and alignment. Doing so now
103 // results in VS2017 in producing an internal error (at the call site) that's
104 // not detailed enough to understand what is happening.
105 template <typename MemoryOpTy>
106 static void printSourceMemoryAccessAttribute(
107     MemoryOpTy memoryOp, OpAsmPrinter &printer,
108     SmallVectorImpl<StringRef> &elidedAttrs,
109     std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110     std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111 
112   printer << ", ";
113 
114   // Print optional memory access attribute.
115   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116                                               : memoryOp.getMemoryAccess())) {
117     elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118 
119     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120 
121     if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122       // Print integer alignment attribute.
123       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124                                                : memoryOp.getAlignment())) {
125         elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126         printer << ", " << *alignment;
127       }
128     }
129     printer << "]";
130   }
131   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132 }
133 
134 template <typename MemoryOpTy>
135 static void printMemoryAccessAttribute(
136     MemoryOpTy memoryOp, OpAsmPrinter &printer,
137     SmallVectorImpl<StringRef> &elidedAttrs,
138     std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139     std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140   // Print optional memory access attribute.
141   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142                                               : memoryOp.getMemoryAccess())) {
143     elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144 
145     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146 
147     if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148       // Print integer alignment attribute.
149       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150                                                : memoryOp.getAlignment())) {
151         elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152         printer << ", " << *alignment;
153       }
154     }
155     printer << "]";
156   }
157   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158 }
159 
160 template <typename LoadStoreOpTy>
161 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162                                                    Value val) {
163   // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164   // type of the pointer and the type of the value are the same
165   //
166   // TODO: Check that the value type satisfies restrictions of
167   // SPIR-V OpLoad/OpStore operations
168   if (val.getType() !=
169       llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
170     return op.emitOpError("mismatch in result type and pointer type");
171   }
172   return success();
173 }
174 
175 template <typename MemoryOpTy>
176 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177   // ODS checks for attributes values. Just need to verify that if the
178   // memory-access attribute is Aligned, then the alignment attribute must be
179   // present.
180   auto *op = memoryOp.getOperation();
181   auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182   if (!memAccessAttr) {
183     // Alignment attribute shouldn't be present if memory access attribute is
184     // not present.
185     if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186       return memoryOp.emitOpError(
187           "invalid alignment specification without aligned memory access "
188           "specification");
189     }
190     return success();
191   }
192 
193   auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194 
195   if (!memAccess) {
196     return memoryOp.emitOpError("invalid memory access specifier: ")
197            << memAccessAttr;
198   }
199 
200   if (spirv::bitEnumContainsAll(memAccess.getValue(),
201                                 spirv::MemoryAccess::Aligned)) {
202     if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203       return memoryOp.emitOpError("missing alignment value");
204     }
205   } else {
206     if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207       return memoryOp.emitOpError(
208           "invalid alignment specification with non-aligned memory access "
209           "specification");
210     }
211   }
212   return success();
213 }
214 
215 // TODO Make sure to merge this and the previous function into one template
216 // parameterized by memory access attribute name and alignment. Doing so now
217 // results in VS2017 in producing an internal error (at the call site) that's
218 // not detailed enough to understand what is happening.
219 template <typename MemoryOpTy>
220 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
221   // ODS checks for attributes values. Just need to verify that if the
222   // memory-access attribute is Aligned, then the alignment attribute must be
223   // present.
224   auto *op = memoryOp.getOperation();
225   auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226   if (!memAccessAttr) {
227     // Alignment attribute shouldn't be present if memory access attribute is
228     // not present.
229     if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230       return memoryOp.emitOpError(
231           "invalid alignment specification without aligned memory access "
232           "specification");
233     }
234     return success();
235   }
236 
237   auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238 
239   if (!memAccess) {
240     return memoryOp.emitOpError("invalid memory access specifier: ")
241            << memAccess;
242   }
243 
244   if (spirv::bitEnumContainsAll(memAccess.getValue(),
245                                 spirv::MemoryAccess::Aligned)) {
246     if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247       return memoryOp.emitOpError("missing alignment value");
248     }
249   } else {
250     if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251       return memoryOp.emitOpError(
252           "invalid alignment specification with non-aligned memory access "
253           "specification");
254     }
255   }
256   return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spirv.AccessChainOp
261 //===----------------------------------------------------------------------===//
262 
263 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
264   auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
265   if (!ptrType) {
266     emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267                        "to composite type, but provided ")
268         << type;
269     return nullptr;
270   }
271 
272   auto resultType = ptrType.getPointeeType();
273   auto resultStorageClass = ptrType.getStorageClass();
274   int32_t index = 0;
275 
276   for (auto indexSSA : indices) {
277     auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
278     if (!cType) {
279       emitError(
280           baseLoc,
281           "'spirv.AccessChain' op cannot extract from non-composite type ")
282           << resultType << " with index " << index;
283       return nullptr;
284     }
285     index = 0;
286     if (llvm::isa<spirv::StructType>(resultType)) {
287       Operation *op = indexSSA.getDefiningOp();
288       if (!op) {
289         emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290                            "integer spirv.Constant to access "
291                            "element of spirv.struct");
292         return nullptr;
293       }
294 
295       // TODO: this should be relaxed to allow
296       // integer literals of other bitwidths.
297       if (failed(spirv::extractValueFromConstOp(op, index))) {
298         emitError(
299             baseLoc,
300             "'spirv.AccessChain' index must be an integer spirv.Constant to "
301             "access element of spirv.struct, but provided ")
302             << op->getName();
303         return nullptr;
304       }
305       if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306         emitError(baseLoc, "'spirv.AccessChain' op index ")
307             << index << " out of bounds for " << resultType;
308         return nullptr;
309       }
310     }
311     resultType = cType.getElementType(index);
312   }
313   return spirv::PointerType::get(resultType, resultStorageClass);
314 }
315 
316 void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317                           Value basePtr, ValueRange indices) {
318   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319   assert(type && "Unable to deduce return type based on basePtr and indices");
320   build(builder, state, type, basePtr, indices);
321 }
322 
323 template <typename Op>
324 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
325   printer << ' ' << op.getBasePtr() << '[' << indices
326           << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327 }
328 
329 template <typename Op>
330 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
331   auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332                                       indices, accessChainOp.getLoc());
333   if (!resultType)
334     return failure();
335 
336   auto providedResultType =
337       llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338   if (!providedResultType)
339     return accessChainOp.emitOpError(
340                "result type must be a pointer, but provided")
341            << providedResultType;
342 
343   if (resultType != providedResultType)
344     return accessChainOp.emitOpError("invalid result type: expected ")
345            << resultType << ", but provided " << providedResultType;
346 
347   return success();
348 }
349 
350 LogicalResult AccessChainOp::verify() {
351   return verifyAccessChain(*this, getIndices());
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // spirv.LoadOp
356 //===----------------------------------------------------------------------===//
357 
358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359                    MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360   auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361   build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362         alignment);
363 }
364 
365 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366   // Parse the storage class specification
367   spirv::StorageClass storageClass;
368   OpAsmParser::UnresolvedOperand ptrInfo;
369   Type elementType;
370   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371       parseMemoryAccessAttributes<LoadOp>(parser, result) ||
372       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373       parser.parseType(elementType)) {
374     return failure();
375   }
376 
377   auto ptrType = spirv::PointerType::get(elementType, storageClass);
378   if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379     return failure();
380   }
381 
382   result.addTypes(elementType);
383   return success();
384 }
385 
386 void LoadOp::print(OpAsmPrinter &printer) {
387   SmallVector<StringRef, 4> elidedAttrs;
388   StringRef sc = stringifyStorageClass(
389       llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
390   printer << " \"" << sc << "\" " << getPtr();
391 
392   printMemoryAccessAttribute(*this, printer, elidedAttrs);
393 
394   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395   printer << " : " << getType();
396 }
397 
398 LogicalResult LoadOp::verify() {
399   // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
400   // type with fixed size; i.e., it cannot be, nor include, any
401   // OpTypeRuntimeArray types."
402   if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
403     return failure();
404   }
405   return verifyMemoryAccessAttribute(*this);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // spirv.StoreOp
410 //===----------------------------------------------------------------------===//
411 
412 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413   // Parse the storage class specification
414   spirv::StorageClass storageClass;
415   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416   auto loc = parser.getCurrentLocation();
417   Type elementType;
418   if (parseEnumStrAttr(storageClass, parser) ||
419       parser.parseOperandList(operandInfo, 2) ||
420       parseMemoryAccessAttributes<StoreOp>(parser, result) ||
421       parser.parseColon() || parser.parseType(elementType)) {
422     return failure();
423   }
424 
425   auto ptrType = spirv::PointerType::get(elementType, storageClass);
426   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427                              result.operands)) {
428     return failure();
429   }
430   return success();
431 }
432 
433 void StoreOp::print(OpAsmPrinter &printer) {
434   SmallVector<StringRef, 4> elidedAttrs;
435   StringRef sc = stringifyStorageClass(
436       llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
437   printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438 
439   printMemoryAccessAttribute(*this, printer, elidedAttrs);
440 
441   printer << " : " << getValue().getType();
442   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443 }
444 
445 LogicalResult StoreOp::verify() {
446   // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
447   // OpTypePointer whose Type operand is the same as the type of Object."
448   if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
449     return failure();
450   return verifyMemoryAccessAttribute(*this);
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // spirv.CopyMemory
455 //===----------------------------------------------------------------------===//
456 
457 void CopyMemoryOp::print(OpAsmPrinter &printer) {
458   printer << ' ';
459 
460   StringRef targetStorageClass = stringifyStorageClass(
461       llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
462   printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463 
464   StringRef sourceStorageClass = stringifyStorageClass(
465       llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
466   printer << " \"" << sourceStorageClass << "\" " << getSource();
467 
468   SmallVector<StringRef, 4> elidedAttrs;
469   printMemoryAccessAttribute(*this, printer, elidedAttrs);
470   printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
471                                    getSourceMemoryAccess(),
472                                    getSourceAlignment());
473 
474   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475 
476   Type pointeeType =
477       llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
478   printer << " : " << pointeeType;
479 }
480 
481 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482   spirv::StorageClass targetStorageClass;
483   OpAsmParser::UnresolvedOperand targetPtrInfo;
484 
485   spirv::StorageClass sourceStorageClass;
486   OpAsmParser::UnresolvedOperand sourcePtrInfo;
487 
488   Type elementType;
489 
490   if (parseEnumStrAttr(targetStorageClass, parser) ||
491       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
492       parseEnumStrAttr(sourceStorageClass, parser) ||
493       parser.parseOperand(sourcePtrInfo) ||
494       parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
495     return failure();
496   }
497 
498   if (!parser.parseOptionalComma()) {
499     // Parse 2nd memory access attributes.
500     if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
501       return failure();
502     }
503   }
504 
505   if (parser.parseColon() || parser.parseType(elementType))
506     return failure();
507 
508   if (parser.parseOptionalAttrDict(result.attributes))
509     return failure();
510 
511   auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
512   auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
513 
514   if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515       parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516     return failure();
517   }
518 
519   return success();
520 }
521 
522 LogicalResult CopyMemoryOp::verify() {
523   Type targetType =
524       llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
525 
526   Type sourceType =
527       llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
528 
529   if (targetType != sourceType)
530     return emitOpError("both operands must be pointers to the same type");
531 
532   if (failed(verifyMemoryAccessAttribute(*this)))
533     return failure();
534 
535   // TODO - According to the spec:
536   //
537   // If two masks are present, the first applies to Target and cannot include
538   // MakePointerVisible, and the second applies to Source and cannot include
539   // MakePointerAvailable.
540   //
541   // Add such verification here.
542 
543   return verifySourceMemoryAccessAttribute(*this);
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // spirv.InBoundsPtrAccessChainOp
548 //===----------------------------------------------------------------------===//
549 
550 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551                                      Value basePtr, Value element,
552                                      ValueRange indices) {
553   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
554   assert(type && "Unable to deduce return type based on basePtr and indices");
555   build(builder, state, type, basePtr, element, indices);
556 }
557 
558 LogicalResult InBoundsPtrAccessChainOp::verify() {
559   return verifyAccessChain(*this, getIndices());
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // spirv.PtrAccessChainOp
564 //===----------------------------------------------------------------------===//
565 
566 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567                              Value basePtr, Value element, ValueRange indices) {
568   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
569   assert(type && "Unable to deduce return type based on basePtr and indices");
570   build(builder, state, type, basePtr, element, indices);
571 }
572 
573 LogicalResult PtrAccessChainOp::verify() {
574   return verifyAccessChain(*this, getIndices());
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // spirv.Variable
579 //===----------------------------------------------------------------------===//
580 
581 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
582   // Parse optional initializer
583   std::optional<OpAsmParser::UnresolvedOperand> initInfo;
584   if (succeeded(parser.parseOptionalKeyword("init"))) {
585     initInfo = OpAsmParser::UnresolvedOperand();
586     if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587         parser.parseRParen())
588       return failure();
589   }
590 
591   if (parseVariableDecorations(parser, result)) {
592     return failure();
593   }
594 
595   // Parse result pointer type
596   Type type;
597   if (parser.parseColon())
598     return failure();
599   auto loc = parser.getCurrentLocation();
600   if (parser.parseType(type))
601     return failure();
602 
603   auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
604   if (!ptrType)
605     return parser.emitError(loc, "expected spirv.ptr type");
606   result.addTypes(ptrType);
607 
608   // Resolve the initializer operand
609   if (initInfo) {
610     if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
611                               result.operands))
612       return failure();
613   }
614 
615   auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
616       ptrType.getStorageClass());
617   result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
618 
619   return success();
620 }
621 
622 void VariableOp::print(OpAsmPrinter &printer) {
623   SmallVector<StringRef, 4> elidedAttrs{
624       spirv::attributeName<spirv::StorageClass>()};
625   // Print optional initializer
626   if (getNumOperands() != 0)
627     printer << " init(" << getInitializer() << ")";
628 
629   printVariableDecorations(*this, printer, elidedAttrs);
630   printer << " : " << getType();
631 }
632 
633 LogicalResult VariableOp::verify() {
634   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
635   // object. It cannot be Generic. It must be the same as the Storage Class
636   // operand of the Result Type."
637   if (getStorageClass() != spirv::StorageClass::Function) {
638     return emitOpError(
639         "can only be used to model function-level variables. Use "
640         "spirv.GlobalVariable for module-level variables.");
641   }
642 
643   auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
644   if (getStorageClass() != pointerType.getStorageClass())
645     return emitOpError(
646         "storage class must match result pointer's storage class");
647 
648   if (getNumOperands() != 0) {
649     // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
650     // a global (module scope) OpVariable instruction".
651     auto *initOp = getOperand(0).getDefiningOp();
652     if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
653                         spirv::ReferenceOfOp, // for spec constant
654                         spirv::AddressOfOp>(initOp))
655       return emitOpError("initializer must be the result of a "
656                          "constant or spirv.GlobalVariable op");
657   }
658 
659   auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
660     return op->getAttr(
661         llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
662   };
663 
664   // TODO: generate these strings using ODS.
665   for (auto decoration :
666        {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667         spirv::Decoration::BuiltIn}) {
668     if (auto attr = getDecorationAttr(decoration))
669       return emitOpError("cannot have '")
670              << llvm::convertToSnakeFromCamelCase(
671                     stringifyDecoration(decoration))
672              << "' attribute (only allowed in spirv.GlobalVariable)";
673   }
674 
675   // From SPV_KHR_physical_storage_buffer:
676   // > If an OpVariable's pointee type is a pointer (or array of pointers) in
677   // > PhysicalStorageBuffer storage class, then the variable must be decorated
678   // > with exactly one of AliasedPointer or RestrictPointer.
679   auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
680   if (!pointeePtrType) {
681     if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
682       pointeePtrType =
683           dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
684     }
685   }
686 
687   if (pointeePtrType && pointeePtrType.getStorageClass() ==
688                             spirv::StorageClass::PhysicalStorageBuffer) {
689     bool hasAliasedPtr =
690         getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
691     bool hasRestrictPtr =
692         getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
693 
694     if (!hasAliasedPtr && !hasRestrictPtr)
695       return emitOpError() << " with physical buffer pointer must be decorated "
696                               "either 'AliasedPointer' or 'RestrictPointer'";
697 
698     if (hasAliasedPtr && hasRestrictPtr)
699       return emitOpError()
700              << " with physical buffer pointer must have exactly one "
701                 "aliasing decoration";
702   }
703 
704   return success();
705 }
706 
707 } // namespace mlir::spirv
708