1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines the memory operations in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 16 #include "SPIRVOpUtils.h" 17 #include "SPIRVParsingUtils.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Diagnostics.h" 20 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/Support/Casting.h" 23 24 using namespace mlir::spirv::AttrNames; 25 26 namespace mlir::spirv { 27 28 /// Parses optional memory access (a.k.a. memory operand) attributes attached to 29 /// a memory access operand/pointer. Specifically, parses the following syntax: 30 /// (`[` memory-access `]`)? 31 /// where: 32 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` 33 /// integer-literal | `"NonTemporal"` 34 template <typename MemoryOpTy> 35 ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, 36 OperationState &state) { 37 // Parse an optional list of attributes staring with '[' 38 if (parser.parseOptionalLSquare()) { 39 // Nothing to do 40 return success(); 41 } 42 43 spirv::MemoryAccess memoryAccessAttr; 44 StringAttr memoryAccessAttrName = 45 MemoryOpTy::getMemoryAccessAttrName(state.name); 46 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>( 47 memoryAccessAttr, parser, state, memoryAccessAttrName)) 48 return failure(); 49 50 if (spirv::bitEnumContainsAll(memoryAccessAttr, 51 spirv::MemoryAccess::Aligned)) { 52 // Parse integer attribute for alignment. 53 Attribute alignmentAttr; 54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name); 55 Type i32Type = parser.getBuilder().getIntegerType(32); 56 if (parser.parseComma() || 57 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, 58 state.attributes)) { 59 return failure(); 60 } 61 } 62 return parser.parseRSquare(); 63 } 64 65 // TODO Make sure to merge this and the previous function into one template 66 // parameterized by memory access attribute name and alignment. Doing so now 67 // results in VS2017 in producing an internal error (at the call site) that's 68 // not detailed enough to understand what is happening. 69 template <typename MemoryOpTy> 70 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, 71 OperationState &state) { 72 // Parse an optional list of attributes staring with '[' 73 if (parser.parseOptionalLSquare()) { 74 // Nothing to do 75 return success(); 76 } 77 78 spirv::MemoryAccess memoryAccessAttr; 79 StringRef memoryAccessAttrName = 80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name); 81 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>( 82 memoryAccessAttr, parser, state, memoryAccessAttrName)) 83 return failure(); 84 85 if (spirv::bitEnumContainsAll(memoryAccessAttr, 86 spirv::MemoryAccess::Aligned)) { 87 // Parse integer attribute for alignment. 88 Attribute alignmentAttr; 89 StringAttr alignmentAttrName = 90 MemoryOpTy::getSourceAlignmentAttrName(state.name); 91 Type i32Type = parser.getBuilder().getIntegerType(32); 92 if (parser.parseComma() || 93 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, 94 state.attributes)) { 95 return failure(); 96 } 97 } 98 return parser.parseRSquare(); 99 } 100 101 // TODO Make sure to merge this and the previous function into one template 102 // parameterized by memory access attribute name and alignment. Doing so now 103 // results in VS2017 in producing an internal error (at the call site) that's 104 // not detailed enough to understand what is happening. 105 template <typename MemoryOpTy> 106 static void printSourceMemoryAccessAttribute( 107 MemoryOpTy memoryOp, OpAsmPrinter &printer, 108 SmallVectorImpl<StringRef> &elidedAttrs, 109 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, 110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) { 111 112 printer << ", "; 113 114 // Print optional memory access attribute. 115 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue 116 : memoryOp.getMemoryAccess())) { 117 elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName()); 118 119 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; 120 121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { 122 // Print integer alignment attribute. 123 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue 124 : memoryOp.getAlignment())) { 125 elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName()); 126 printer << ", " << *alignment; 127 } 128 } 129 printer << "]"; 130 } 131 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 132 } 133 134 template <typename MemoryOpTy> 135 static void printMemoryAccessAttribute( 136 MemoryOpTy memoryOp, OpAsmPrinter &printer, 137 SmallVectorImpl<StringRef> &elidedAttrs, 138 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, 139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) { 140 // Print optional memory access attribute. 141 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue 142 : memoryOp.getMemoryAccess())) { 143 elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName()); 144 145 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; 146 147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { 148 // Print integer alignment attribute. 149 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue 150 : memoryOp.getAlignment())) { 151 elidedAttrs.push_back(memoryOp.getAlignmentAttrName()); 152 printer << ", " << *alignment; 153 } 154 } 155 printer << "]"; 156 } 157 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 158 } 159 160 template <typename LoadStoreOpTy> 161 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, 162 Value val) { 163 // ODS already checks ptr is spirv::PointerType. Just check that the pointee 164 // type of the pointer and the type of the value are the same 165 // 166 // TODO: Check that the value type satisfies restrictions of 167 // SPIR-V OpLoad/OpStore operations 168 if (val.getType() != 169 llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) { 170 return op.emitOpError("mismatch in result type and pointer type"); 171 } 172 return success(); 173 } 174 175 template <typename MemoryOpTy> 176 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { 177 // ODS checks for attributes values. Just need to verify that if the 178 // memory-access attribute is Aligned, then the alignment attribute must be 179 // present. 180 auto *op = memoryOp.getOperation(); 181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName()); 182 if (!memAccessAttr) { 183 // Alignment attribute shouldn't be present if memory access attribute is 184 // not present. 185 if (op->getAttr(memoryOp.getAlignmentAttrName())) { 186 return memoryOp.emitOpError( 187 "invalid alignment specification without aligned memory access " 188 "specification"); 189 } 190 return success(); 191 } 192 193 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr); 194 195 if (!memAccess) { 196 return memoryOp.emitOpError("invalid memory access specifier: ") 197 << memAccessAttr; 198 } 199 200 if (spirv::bitEnumContainsAll(memAccess.getValue(), 201 spirv::MemoryAccess::Aligned)) { 202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) { 203 return memoryOp.emitOpError("missing alignment value"); 204 } 205 } else { 206 if (op->getAttr(memoryOp.getAlignmentAttrName())) { 207 return memoryOp.emitOpError( 208 "invalid alignment specification with non-aligned memory access " 209 "specification"); 210 } 211 } 212 return success(); 213 } 214 215 // TODO Make sure to merge this and the previous function into one template 216 // parameterized by memory access attribute name and alignment. Doing so now 217 // results in VS2017 in producing an internal error (at the call site) that's 218 // not detailed enough to understand what is happening. 219 template <typename MemoryOpTy> 220 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { 221 // ODS checks for attributes values. Just need to verify that if the 222 // memory-access attribute is Aligned, then the alignment attribute must be 223 // present. 224 auto *op = memoryOp.getOperation(); 225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName()); 226 if (!memAccessAttr) { 227 // Alignment attribute shouldn't be present if memory access attribute is 228 // not present. 229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) { 230 return memoryOp.emitOpError( 231 "invalid alignment specification without aligned memory access " 232 "specification"); 233 } 234 return success(); 235 } 236 237 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr); 238 239 if (!memAccess) { 240 return memoryOp.emitOpError("invalid memory access specifier: ") 241 << memAccess; 242 } 243 244 if (spirv::bitEnumContainsAll(memAccess.getValue(), 245 spirv::MemoryAccess::Aligned)) { 246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) { 247 return memoryOp.emitOpError("missing alignment value"); 248 } 249 } else { 250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) { 251 return memoryOp.emitOpError( 252 "invalid alignment specification with non-aligned memory access " 253 "specification"); 254 } 255 } 256 return success(); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // spirv.AccessChainOp 261 //===----------------------------------------------------------------------===// 262 263 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { 264 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type); 265 if (!ptrType) { 266 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " 267 "to composite type, but provided ") 268 << type; 269 return nullptr; 270 } 271 272 auto resultType = ptrType.getPointeeType(); 273 auto resultStorageClass = ptrType.getStorageClass(); 274 int32_t index = 0; 275 276 for (auto indexSSA : indices) { 277 auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType); 278 if (!cType) { 279 emitError( 280 baseLoc, 281 "'spirv.AccessChain' op cannot extract from non-composite type ") 282 << resultType << " with index " << index; 283 return nullptr; 284 } 285 index = 0; 286 if (llvm::isa<spirv::StructType>(resultType)) { 287 Operation *op = indexSSA.getDefiningOp(); 288 if (!op) { 289 emitError(baseLoc, "'spirv.AccessChain' op index must be an " 290 "integer spirv.Constant to access " 291 "element of spirv.struct"); 292 return nullptr; 293 } 294 295 // TODO: this should be relaxed to allow 296 // integer literals of other bitwidths. 297 if (failed(spirv::extractValueFromConstOp(op, index))) { 298 emitError( 299 baseLoc, 300 "'spirv.AccessChain' index must be an integer spirv.Constant to " 301 "access element of spirv.struct, but provided ") 302 << op->getName(); 303 return nullptr; 304 } 305 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { 306 emitError(baseLoc, "'spirv.AccessChain' op index ") 307 << index << " out of bounds for " << resultType; 308 return nullptr; 309 } 310 } 311 resultType = cType.getElementType(index); 312 } 313 return spirv::PointerType::get(resultType, resultStorageClass); 314 } 315 316 void AccessChainOp::build(OpBuilder &builder, OperationState &state, 317 Value basePtr, ValueRange indices) { 318 auto type = getElementPtrType(basePtr.getType(), indices, state.location); 319 assert(type && "Unable to deduce return type based on basePtr and indices"); 320 build(builder, state, type, basePtr, indices); 321 } 322 323 template <typename Op> 324 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { 325 printer << ' ' << op.getBasePtr() << '[' << indices 326 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); 327 } 328 329 template <typename Op> 330 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { 331 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), 332 indices, accessChainOp.getLoc()); 333 if (!resultType) 334 return failure(); 335 336 auto providedResultType = 337 llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType()); 338 if (!providedResultType) 339 return accessChainOp.emitOpError( 340 "result type must be a pointer, but provided") 341 << providedResultType; 342 343 if (resultType != providedResultType) 344 return accessChainOp.emitOpError("invalid result type: expected ") 345 << resultType << ", but provided " << providedResultType; 346 347 return success(); 348 } 349 350 LogicalResult AccessChainOp::verify() { 351 return verifyAccessChain(*this, getIndices()); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // spirv.LoadOp 356 //===----------------------------------------------------------------------===// 357 358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, 359 MemoryAccessAttr memoryAccess, IntegerAttr alignment) { 360 auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType()); 361 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, 362 alignment); 363 } 364 365 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 366 // Parse the storage class specification 367 spirv::StorageClass storageClass; 368 OpAsmParser::UnresolvedOperand ptrInfo; 369 Type elementType; 370 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || 371 parseMemoryAccessAttributes<LoadOp>(parser, result) || 372 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 373 parser.parseType(elementType)) { 374 return failure(); 375 } 376 377 auto ptrType = spirv::PointerType::get(elementType, storageClass); 378 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { 379 return failure(); 380 } 381 382 result.addTypes(elementType); 383 return success(); 384 } 385 386 void LoadOp::print(OpAsmPrinter &printer) { 387 SmallVector<StringRef, 4> elidedAttrs; 388 StringRef sc = stringifyStorageClass( 389 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass()); 390 printer << " \"" << sc << "\" " << getPtr(); 391 392 printMemoryAccessAttribute(*this, printer, elidedAttrs); 393 394 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); 395 printer << " : " << getType(); 396 } 397 398 LogicalResult LoadOp::verify() { 399 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a 400 // type with fixed size; i.e., it cannot be, nor include, any 401 // OpTypeRuntimeArray types." 402 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { 403 return failure(); 404 } 405 return verifyMemoryAccessAttribute(*this); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // spirv.StoreOp 410 //===----------------------------------------------------------------------===// 411 412 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 413 // Parse the storage class specification 414 spirv::StorageClass storageClass; 415 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; 416 auto loc = parser.getCurrentLocation(); 417 Type elementType; 418 if (parseEnumStrAttr(storageClass, parser) || 419 parser.parseOperandList(operandInfo, 2) || 420 parseMemoryAccessAttributes<StoreOp>(parser, result) || 421 parser.parseColon() || parser.parseType(elementType)) { 422 return failure(); 423 } 424 425 auto ptrType = spirv::PointerType::get(elementType, storageClass); 426 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, 427 result.operands)) { 428 return failure(); 429 } 430 return success(); 431 } 432 433 void StoreOp::print(OpAsmPrinter &printer) { 434 SmallVector<StringRef, 4> elidedAttrs; 435 StringRef sc = stringifyStorageClass( 436 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass()); 437 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); 438 439 printMemoryAccessAttribute(*this, printer, elidedAttrs); 440 441 printer << " : " << getValue().getType(); 442 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); 443 } 444 445 LogicalResult StoreOp::verify() { 446 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an 447 // OpTypePointer whose Type operand is the same as the type of Object." 448 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) 449 return failure(); 450 return verifyMemoryAccessAttribute(*this); 451 } 452 453 //===----------------------------------------------------------------------===// 454 // spirv.CopyMemory 455 //===----------------------------------------------------------------------===// 456 457 void CopyMemoryOp::print(OpAsmPrinter &printer) { 458 printer << ' '; 459 460 StringRef targetStorageClass = stringifyStorageClass( 461 llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass()); 462 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; 463 464 StringRef sourceStorageClass = stringifyStorageClass( 465 llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass()); 466 printer << " \"" << sourceStorageClass << "\" " << getSource(); 467 468 SmallVector<StringRef, 4> elidedAttrs; 469 printMemoryAccessAttribute(*this, printer, elidedAttrs); 470 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, 471 getSourceMemoryAccess(), 472 getSourceAlignment()); 473 474 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); 475 476 Type pointeeType = 477 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType(); 478 printer << " : " << pointeeType; 479 } 480 481 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) { 482 spirv::StorageClass targetStorageClass; 483 OpAsmParser::UnresolvedOperand targetPtrInfo; 484 485 spirv::StorageClass sourceStorageClass; 486 OpAsmParser::UnresolvedOperand sourcePtrInfo; 487 488 Type elementType; 489 490 if (parseEnumStrAttr(targetStorageClass, parser) || 491 parser.parseOperand(targetPtrInfo) || parser.parseComma() || 492 parseEnumStrAttr(sourceStorageClass, parser) || 493 parser.parseOperand(sourcePtrInfo) || 494 parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) { 495 return failure(); 496 } 497 498 if (!parser.parseOptionalComma()) { 499 // Parse 2nd memory access attributes. 500 if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) { 501 return failure(); 502 } 503 } 504 505 if (parser.parseColon() || parser.parseType(elementType)) 506 return failure(); 507 508 if (parser.parseOptionalAttrDict(result.attributes)) 509 return failure(); 510 511 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); 512 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); 513 514 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || 515 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { 516 return failure(); 517 } 518 519 return success(); 520 } 521 522 LogicalResult CopyMemoryOp::verify() { 523 Type targetType = 524 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType(); 525 526 Type sourceType = 527 llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType(); 528 529 if (targetType != sourceType) 530 return emitOpError("both operands must be pointers to the same type"); 531 532 if (failed(verifyMemoryAccessAttribute(*this))) 533 return failure(); 534 535 // TODO - According to the spec: 536 // 537 // If two masks are present, the first applies to Target and cannot include 538 // MakePointerVisible, and the second applies to Source and cannot include 539 // MakePointerAvailable. 540 // 541 // Add such verification here. 542 543 return verifySourceMemoryAccessAttribute(*this); 544 } 545 546 //===----------------------------------------------------------------------===// 547 // spirv.InBoundsPtrAccessChainOp 548 //===----------------------------------------------------------------------===// 549 550 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state, 551 Value basePtr, Value element, 552 ValueRange indices) { 553 auto type = getElementPtrType(basePtr.getType(), indices, state.location); 554 assert(type && "Unable to deduce return type based on basePtr and indices"); 555 build(builder, state, type, basePtr, element, indices); 556 } 557 558 LogicalResult InBoundsPtrAccessChainOp::verify() { 559 return verifyAccessChain(*this, getIndices()); 560 } 561 562 //===----------------------------------------------------------------------===// 563 // spirv.PtrAccessChainOp 564 //===----------------------------------------------------------------------===// 565 566 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, 567 Value basePtr, Value element, ValueRange indices) { 568 auto type = getElementPtrType(basePtr.getType(), indices, state.location); 569 assert(type && "Unable to deduce return type based on basePtr and indices"); 570 build(builder, state, type, basePtr, element, indices); 571 } 572 573 LogicalResult PtrAccessChainOp::verify() { 574 return verifyAccessChain(*this, getIndices()); 575 } 576 577 //===----------------------------------------------------------------------===// 578 // spirv.Variable 579 //===----------------------------------------------------------------------===// 580 581 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) { 582 // Parse optional initializer 583 std::optional<OpAsmParser::UnresolvedOperand> initInfo; 584 if (succeeded(parser.parseOptionalKeyword("init"))) { 585 initInfo = OpAsmParser::UnresolvedOperand(); 586 if (parser.parseLParen() || parser.parseOperand(*initInfo) || 587 parser.parseRParen()) 588 return failure(); 589 } 590 591 if (parseVariableDecorations(parser, result)) { 592 return failure(); 593 } 594 595 // Parse result pointer type 596 Type type; 597 if (parser.parseColon()) 598 return failure(); 599 auto loc = parser.getCurrentLocation(); 600 if (parser.parseType(type)) 601 return failure(); 602 603 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type); 604 if (!ptrType) 605 return parser.emitError(loc, "expected spirv.ptr type"); 606 result.addTypes(ptrType); 607 608 // Resolve the initializer operand 609 if (initInfo) { 610 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), 611 result.operands)) 612 return failure(); 613 } 614 615 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>( 616 ptrType.getStorageClass()); 617 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr); 618 619 return success(); 620 } 621 622 void VariableOp::print(OpAsmPrinter &printer) { 623 SmallVector<StringRef, 4> elidedAttrs{ 624 spirv::attributeName<spirv::StorageClass>()}; 625 // Print optional initializer 626 if (getNumOperands() != 0) 627 printer << " init(" << getInitializer() << ")"; 628 629 printVariableDecorations(*this, printer, elidedAttrs); 630 printer << " : " << getType(); 631 } 632 633 LogicalResult VariableOp::verify() { 634 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the 635 // object. It cannot be Generic. It must be the same as the Storage Class 636 // operand of the Result Type." 637 if (getStorageClass() != spirv::StorageClass::Function) { 638 return emitOpError( 639 "can only be used to model function-level variables. Use " 640 "spirv.GlobalVariable for module-level variables."); 641 } 642 643 auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType()); 644 if (getStorageClass() != pointerType.getStorageClass()) 645 return emitOpError( 646 "storage class must match result pointer's storage class"); 647 648 if (getNumOperands() != 0) { 649 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or 650 // a global (module scope) OpVariable instruction". 651 auto *initOp = getOperand(0).getDefiningOp(); 652 if (!initOp || !isa<spirv::ConstantOp, // for normal constant 653 spirv::ReferenceOfOp, // for spec constant 654 spirv::AddressOfOp>(initOp)) 655 return emitOpError("initializer must be the result of a " 656 "constant or spirv.GlobalVariable op"); 657 } 658 659 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) { 660 return op->getAttr( 661 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration))); 662 }; 663 664 // TODO: generate these strings using ODS. 665 for (auto decoration : 666 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding, 667 spirv::Decoration::BuiltIn}) { 668 if (auto attr = getDecorationAttr(decoration)) 669 return emitOpError("cannot have '") 670 << llvm::convertToSnakeFromCamelCase( 671 stringifyDecoration(decoration)) 672 << "' attribute (only allowed in spirv.GlobalVariable)"; 673 } 674 675 // From SPV_KHR_physical_storage_buffer: 676 // > If an OpVariable's pointee type is a pointer (or array of pointers) in 677 // > PhysicalStorageBuffer storage class, then the variable must be decorated 678 // > with exactly one of AliasedPointer or RestrictPointer. 679 auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType()); 680 if (!pointeePtrType) { 681 if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) { 682 pointeePtrType = 683 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType()); 684 } 685 } 686 687 if (pointeePtrType && pointeePtrType.getStorageClass() == 688 spirv::StorageClass::PhysicalStorageBuffer) { 689 bool hasAliasedPtr = 690 getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr; 691 bool hasRestrictPtr = 692 getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr; 693 694 if (!hasAliasedPtr && !hasRestrictPtr) 695 return emitOpError() << " with physical buffer pointer must be decorated " 696 "either 'AliasedPointer' or 'RestrictPointer'"; 697 698 if (hasAliasedPtr && hasRestrictPtr) 699 return emitOpError() 700 << " with physical buffer pointer must have exactly one " 701 "aliasing decoration"; 702 } 703 704 return success(); 705 } 706 707 } // namespace mlir::spirv 708