1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the operations in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 14 15 #include "SPIRVOpUtils.h" 16 #include "SPIRVParsingUtils.h" 17 18 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 20 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 21 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" 22 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 23 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/BuiltinTypes.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/OpDefinition.h" 28 #include "mlir/IR/OpImplementation.h" 29 #include "mlir/IR/Operation.h" 30 #include "mlir/IR/TypeUtilities.h" 31 #include "mlir/Interfaces/FunctionImplementation.h" 32 #include "llvm/ADT/APFloat.h" 33 #include "llvm/ADT/APInt.h" 34 #include "llvm/ADT/ArrayRef.h" 35 #include "llvm/ADT/STLExtras.h" 36 #include "llvm/ADT/StringExtras.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 #include <cassert> 39 #include <numeric> 40 #include <optional> 41 #include <type_traits> 42 43 using namespace mlir; 44 using namespace mlir::spirv::AttrNames; 45 46 //===----------------------------------------------------------------------===// 47 // Common utility functions 48 //===----------------------------------------------------------------------===// 49 50 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { 51 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op); 52 if (!constOp) { 53 return failure(); 54 } 55 auto valueAttr = constOp.getValue(); 56 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr); 57 if (!integerValueAttr) { 58 return failure(); 59 } 60 61 if (integerValueAttr.getType().isSignlessInteger()) 62 value = integerValueAttr.getInt(); 63 else 64 value = integerValueAttr.getSInt(); 65 66 return success(); 67 } 68 69 LogicalResult 70 spirv::verifyMemorySemantics(Operation *op, 71 spirv::MemorySemantics memorySemantics) { 72 // According to the SPIR-V specification: 73 // "Despite being a mask and allowing multiple bits to be combined, it is 74 // invalid for more than one of these four bits to be set: Acquire, Release, 75 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and 76 // Release semantics is done by setting the AcquireRelease bit, not by setting 77 // two bits." 78 auto atMostOneInSet = spirv::MemorySemantics::Acquire | 79 spirv::MemorySemantics::Release | 80 spirv::MemorySemantics::AcquireRelease | 81 spirv::MemorySemantics::SequentiallyConsistent; 82 83 auto bitCount = 84 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet)); 85 if (bitCount > 1) { 86 return op->emitError( 87 "expected at most one of these four memory constraints " 88 "to be set: `Acquire`, `Release`," 89 "`AcquireRelease` or `SequentiallyConsistent`"); 90 } 91 return success(); 92 } 93 94 void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer, 95 SmallVectorImpl<StringRef> &elidedAttrs) { 96 // Print optional descriptor binding 97 auto descriptorSetName = llvm::convertToSnakeFromCamelCase( 98 stringifyDecoration(spirv::Decoration::DescriptorSet)); 99 auto bindingName = llvm::convertToSnakeFromCamelCase( 100 stringifyDecoration(spirv::Decoration::Binding)); 101 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); 102 auto binding = op->getAttrOfType<IntegerAttr>(bindingName); 103 if (descriptorSet && binding) { 104 elidedAttrs.push_back(descriptorSetName); 105 elidedAttrs.push_back(bindingName); 106 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() 107 << ")"; 108 } 109 110 // Print BuiltIn attribute if present 111 auto builtInName = llvm::convertToSnakeFromCamelCase( 112 stringifyDecoration(spirv::Decoration::BuiltIn)); 113 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { 114 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; 115 elidedAttrs.push_back(builtInName); 116 } 117 118 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 119 } 120 121 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, 122 OperationState &result) { 123 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops; 124 Type type; 125 // If the operand list is in-between parentheses, then we have a generic form. 126 // (see the fallback in `printOneResultOp`). 127 SMLoc loc = parser.getCurrentLocation(); 128 if (!parser.parseOptionalLParen()) { 129 if (parser.parseOperandList(ops) || parser.parseRParen() || 130 parser.parseOptionalAttrDict(result.attributes) || 131 parser.parseColon() || parser.parseType(type)) 132 return failure(); 133 auto fnType = llvm::dyn_cast<FunctionType>(type); 134 if (!fnType) { 135 parser.emitError(loc, "expected function type"); 136 return failure(); 137 } 138 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) 139 return failure(); 140 result.addTypes(fnType.getResults()); 141 return success(); 142 } 143 return failure(parser.parseOperandList(ops) || 144 parser.parseOptionalAttrDict(result.attributes) || 145 parser.parseColonType(type) || 146 parser.resolveOperands(ops, type, result.operands) || 147 parser.addTypeToList(type, result.types)); 148 } 149 150 static void printOneResultOp(Operation *op, OpAsmPrinter &p) { 151 assert(op->getNumResults() == 1 && "op should have one result"); 152 153 // If not all the operand and result types are the same, just use the 154 // generic assembly form to avoid omitting information in printing. 155 auto resultType = op->getResult(0).getType(); 156 if (llvm::any_of(op->getOperandTypes(), 157 [&](Type type) { return type != resultType; })) { 158 p.printGenericOp(op, /*printOpName=*/false); 159 return; 160 } 161 162 p << ' '; 163 p.printOperands(op->getOperands()); 164 p.printOptionalAttrDict(op->getAttrs()); 165 // Now we can output only one type for all operands and the result. 166 p << " : " << resultType; 167 } 168 169 template <typename Op> 170 static LogicalResult verifyImageOperands(Op imageOp, 171 spirv::ImageOperandsAttr attr, 172 Operation::operand_range operands) { 173 if (!attr) { 174 if (operands.empty()) 175 return success(); 176 177 return imageOp.emitError("the Image Operands should encode what operands " 178 "follow, as per Image Operands"); 179 } 180 181 // TODO: Add the validation rules for the following Image Operands. 182 spirv::ImageOperands noSupportOperands = 183 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | 184 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | 185 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | 186 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | 187 spirv::ImageOperands::MakeTexelAvailable | 188 spirv::ImageOperands::MakeTexelVisible | 189 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; 190 191 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands)) 192 llvm_unreachable("unimplemented operands of Image Operands"); 193 194 return success(); 195 } 196 197 template <typename BlockReadWriteOpTy> 198 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, 199 Value ptr, Value val) { 200 auto valType = val.getType(); 201 if (auto valVecTy = llvm::dyn_cast<VectorType>(valType)) 202 valType = valVecTy.getElementType(); 203 204 if (valType != 205 llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) { 206 return op.emitOpError("mismatch in result type and pointer type"); 207 } 208 return success(); 209 } 210 211 /// Walks the given type hierarchy with the given indices, potentially down 212 /// to component granularity, to select an element type. Returns null type and 213 /// emits errors with the given loc on failure. 214 static Type 215 getElementType(Type type, ArrayRef<int32_t> indices, 216 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { 217 if (indices.empty()) { 218 emitErrorFn("expected at least one index for spirv.CompositeExtract"); 219 return nullptr; 220 } 221 222 for (auto index : indices) { 223 if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) { 224 if (cType.hasCompileTimeKnownNumElements() && 225 (index < 0 || 226 static_cast<uint64_t>(index) >= cType.getNumElements())) { 227 emitErrorFn("index ") << index << " out of bounds for " << type; 228 return nullptr; 229 } 230 type = cType.getElementType(index); 231 } else { 232 emitErrorFn("cannot extract from non-composite type ") 233 << type << " with index " << index; 234 return nullptr; 235 } 236 } 237 return type; 238 } 239 240 static Type 241 getElementType(Type type, Attribute indices, 242 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { 243 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices); 244 if (!indicesArrayAttr) { 245 emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); 246 return nullptr; 247 } 248 if (indicesArrayAttr.empty()) { 249 emitErrorFn("expected at least one index for spirv.CompositeExtract"); 250 return nullptr; 251 } 252 253 SmallVector<int32_t, 2> indexVals; 254 for (auto indexAttr : indicesArrayAttr) { 255 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr); 256 if (!indexIntAttr) { 257 emitErrorFn("expected an 32-bit integer for index, but found '") 258 << indexAttr << "'"; 259 return nullptr; 260 } 261 indexVals.push_back(indexIntAttr.getInt()); 262 } 263 return getElementType(type, indexVals, emitErrorFn); 264 } 265 266 static Type getElementType(Type type, Attribute indices, Location loc) { 267 auto errorFn = [&](StringRef err) -> InFlightDiagnostic { 268 return ::mlir::emitError(loc, err); 269 }; 270 return getElementType(type, indices, errorFn); 271 } 272 273 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, 274 SMLoc loc) { 275 auto errorFn = [&](StringRef err) -> InFlightDiagnostic { 276 return parser.emitError(loc, err); 277 }; 278 return getElementType(type, indices, errorFn); 279 } 280 281 template <typename ExtendedBinaryOp> 282 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { 283 auto resultType = llvm::cast<spirv::StructType>(op.getType()); 284 if (resultType.getNumElements() != 2) 285 return op.emitOpError("expected result struct type containing two members"); 286 287 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), 288 resultType.getElementType(0), 289 resultType.getElementType(1)})) 290 return op.emitOpError( 291 "expected all operand types and struct member types are the same"); 292 293 return success(); 294 } 295 296 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, 297 OperationState &result) { 298 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 299 if (parser.parseOptionalAttrDict(result.attributes) || 300 parser.parseOperandList(operands) || parser.parseColon()) 301 return failure(); 302 303 Type resultType; 304 SMLoc loc = parser.getCurrentLocation(); 305 if (parser.parseType(resultType)) 306 return failure(); 307 308 auto structType = llvm::dyn_cast<spirv::StructType>(resultType); 309 if (!structType || structType.getNumElements() != 2) 310 return parser.emitError(loc, "expected spirv.struct type with two members"); 311 312 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0)); 313 if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) 314 return failure(); 315 316 result.addTypes(resultType); 317 return success(); 318 } 319 320 static void printArithmeticExtendedBinaryOp(Operation *op, 321 OpAsmPrinter &printer) { 322 printer << ' '; 323 printer.printOptionalAttrDict(op->getAttrs()); 324 printer.printOperands(op->getOperands()); 325 printer << " : " << op->getResultTypes().front(); 326 } 327 328 static LogicalResult verifyShiftOp(Operation *op) { 329 if (op->getOperand(0).getType() != op->getResult(0).getType()) { 330 return op->emitError("expected the same type for the first operand and " 331 "result, but provided ") 332 << op->getOperand(0).getType() << " and " 333 << op->getResult(0).getType(); 334 } 335 return success(); 336 } 337 338 //===----------------------------------------------------------------------===// 339 // spirv.mlir.addressof 340 //===----------------------------------------------------------------------===// 341 342 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, 343 spirv::GlobalVariableOp var) { 344 build(builder, state, var.getType(), SymbolRefAttr::get(var)); 345 } 346 347 LogicalResult spirv::AddressOfOp::verify() { 348 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>( 349 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), 350 getVariableAttr())); 351 if (!varOp) { 352 return emitOpError("expected spirv.GlobalVariable symbol"); 353 } 354 if (getPointer().getType() != varOp.getType()) { 355 return emitOpError( 356 "result type mismatch with the referenced global variable's type"); 357 } 358 return success(); 359 } 360 361 //===----------------------------------------------------------------------===// 362 // spirv.CompositeConstruct 363 //===----------------------------------------------------------------------===// 364 365 LogicalResult spirv::CompositeConstructOp::verify() { 366 operand_range constituents = this->getConstituents(); 367 368 // There are 4 cases with varying verification rules: 369 // 1. Cooperative Matrices (1 constituent) 370 // 2. Structs (1 constituent for each member) 371 // 3. Arrays (1 constituent for each array element) 372 // 4. Vectors (1 constituent (sub-)element for each vector element) 373 374 auto coopElementType = 375 llvm::TypeSwitch<Type, Type>(getType()) 376 .Case<spirv::CooperativeMatrixType>( 377 [](auto coopType) { return coopType.getElementType(); }) 378 .Default([](Type) { return nullptr; }); 379 380 // Case 1. -- matrices. 381 if (coopElementType) { 382 if (constituents.size() != 1) 383 return emitOpError("has incorrect number of operands: expected ") 384 << "1, but provided " << constituents.size(); 385 if (coopElementType != constituents.front().getType()) 386 return emitOpError("operand type mismatch: expected operand type ") 387 << coopElementType << ", but provided " 388 << constituents.front().getType(); 389 return success(); 390 } 391 392 // Case 2./3./4. -- number of constituents matches the number of elements. 393 auto cType = llvm::cast<spirv::CompositeType>(getType()); 394 if (constituents.size() == cType.getNumElements()) { 395 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 396 if (constituents[index].getType() != cType.getElementType(index)) { 397 return emitOpError("operand type mismatch: expected operand type ") 398 << cType.getElementType(index) << ", but provided " 399 << constituents[index].getType(); 400 } 401 } 402 return success(); 403 } 404 405 // Case 4. -- check that all constituents add up tp the expected vector type. 406 auto resultType = llvm::dyn_cast<VectorType>(cType); 407 if (!resultType) 408 return emitOpError( 409 "expected to return a vector or cooperative matrix when the number of " 410 "constituents is less than what the result needs"); 411 412 SmallVector<unsigned> sizes; 413 for (Value component : constituents) { 414 if (!llvm::isa<VectorType>(component.getType()) && 415 !component.getType().isIntOrFloat()) 416 return emitOpError("operand type mismatch: expected operand to have " 417 "a scalar or vector type, but provided ") 418 << component.getType(); 419 420 Type elementType = component.getType(); 421 if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) { 422 sizes.push_back(vectorType.getNumElements()); 423 elementType = vectorType.getElementType(); 424 } else { 425 sizes.push_back(1); 426 } 427 428 if (elementType != resultType.getElementType()) 429 return emitOpError("operand element type mismatch: expected to be ") 430 << resultType.getElementType() << ", but provided " << elementType; 431 } 432 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); 433 if (totalCount != cType.getNumElements()) 434 return emitOpError("has incorrect number of operands: expected ") 435 << cType.getNumElements() << ", but provided " << totalCount; 436 return success(); 437 } 438 439 //===----------------------------------------------------------------------===// 440 // spirv.CompositeExtractOp 441 //===----------------------------------------------------------------------===// 442 443 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, 444 Value composite, 445 ArrayRef<int32_t> indices) { 446 auto indexAttr = builder.getI32ArrayAttr(indices); 447 auto elementType = 448 getElementType(composite.getType(), indexAttr, state.location); 449 if (!elementType) { 450 return; 451 } 452 build(builder, state, elementType, composite, indexAttr); 453 } 454 455 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, 456 OperationState &result) { 457 OpAsmParser::UnresolvedOperand compositeInfo; 458 Attribute indicesAttr; 459 StringRef indicesAttrName = 460 spirv::CompositeExtractOp::getIndicesAttrName(result.name); 461 Type compositeType; 462 SMLoc attrLocation; 463 464 if (parser.parseOperand(compositeInfo) || 465 parser.getCurrentLocation(&attrLocation) || 466 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) || 467 parser.parseColonType(compositeType) || 468 parser.resolveOperand(compositeInfo, compositeType, result.operands)) { 469 return failure(); 470 } 471 472 Type resultType = 473 getElementType(compositeType, indicesAttr, parser, attrLocation); 474 if (!resultType) { 475 return failure(); 476 } 477 result.addTypes(resultType); 478 return success(); 479 } 480 481 void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { 482 printer << ' ' << getComposite() << getIndices() << " : " 483 << getComposite().getType(); 484 } 485 486 LogicalResult spirv::CompositeExtractOp::verify() { 487 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices()); 488 auto resultType = 489 getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); 490 if (!resultType) 491 return failure(); 492 493 if (resultType != getType()) { 494 return emitOpError("invalid result type: expected ") 495 << resultType << " but provided " << getType(); 496 } 497 498 return success(); 499 } 500 501 //===----------------------------------------------------------------------===// 502 // spirv.CompositeInsert 503 //===----------------------------------------------------------------------===// 504 505 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, 506 Value object, Value composite, 507 ArrayRef<int32_t> indices) { 508 auto indexAttr = builder.getI32ArrayAttr(indices); 509 build(builder, state, composite.getType(), object, composite, indexAttr); 510 } 511 512 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, 513 OperationState &result) { 514 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 515 Type objectType, compositeType; 516 Attribute indicesAttr; 517 StringRef indicesAttrName = 518 spirv::CompositeInsertOp::getIndicesAttrName(result.name); 519 auto loc = parser.getCurrentLocation(); 520 521 return failure( 522 parser.parseOperandList(operands, 2) || 523 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) || 524 parser.parseColonType(objectType) || 525 parser.parseKeywordType("into", compositeType) || 526 parser.resolveOperands(operands, {objectType, compositeType}, loc, 527 result.operands) || 528 parser.addTypesToList(compositeType, result.types)); 529 } 530 531 LogicalResult spirv::CompositeInsertOp::verify() { 532 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices()); 533 auto objectType = 534 getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); 535 if (!objectType) 536 return failure(); 537 538 if (objectType != getObject().getType()) { 539 return emitOpError("object operand type should be ") 540 << objectType << ", but found " << getObject().getType(); 541 } 542 543 if (getComposite().getType() != getType()) { 544 return emitOpError("result type should be the same as " 545 "the composite type, but found ") 546 << getComposite().getType() << " vs " << getType(); 547 } 548 549 return success(); 550 } 551 552 void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { 553 printer << " " << getObject() << ", " << getComposite() << getIndices() 554 << " : " << getObject().getType() << " into " 555 << getComposite().getType(); 556 } 557 558 //===----------------------------------------------------------------------===// 559 // spirv.Constant 560 //===----------------------------------------------------------------------===// 561 562 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, 563 OperationState &result) { 564 Attribute value; 565 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name); 566 if (parser.parseAttribute(value, valueAttrName, result.attributes)) 567 return failure(); 568 569 Type type = NoneType::get(parser.getContext()); 570 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value)) 571 type = typedAttr.getType(); 572 if (llvm::isa<NoneType, TensorType>(type)) { 573 if (parser.parseColonType(type)) 574 return failure(); 575 } 576 577 return parser.addTypeToList(type, result.types); 578 } 579 580 void spirv::ConstantOp::print(OpAsmPrinter &printer) { 581 printer << ' ' << getValue(); 582 if (llvm::isa<spirv::ArrayType>(getType())) 583 printer << " : " << getType(); 584 } 585 586 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, 587 Type opType) { 588 if (llvm::isa<IntegerAttr, FloatAttr>(value)) { 589 auto valueType = llvm::cast<TypedAttr>(value).getType(); 590 if (valueType != opType) 591 return op.emitOpError("result type (") 592 << opType << ") does not match value type (" << valueType << ")"; 593 return success(); 594 } 595 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) { 596 auto valueType = llvm::cast<TypedAttr>(value).getType(); 597 if (valueType == opType) 598 return success(); 599 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType); 600 auto shapedType = llvm::dyn_cast<ShapedType>(valueType); 601 if (!arrayType) 602 return op.emitOpError("result or element type (") 603 << opType << ") does not match value type (" << valueType 604 << "), must be the same or spirv.array"; 605 606 int numElements = arrayType.getNumElements(); 607 auto opElemType = arrayType.getElementType(); 608 while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) { 609 numElements *= t.getNumElements(); 610 opElemType = t.getElementType(); 611 } 612 if (!opElemType.isIntOrFloat()) 613 return op.emitOpError("only support nested array result type"); 614 615 auto valueElemType = shapedType.getElementType(); 616 if (valueElemType != opElemType) { 617 return op.emitOpError("result element type (") 618 << opElemType << ") does not match value element type (" 619 << valueElemType << ")"; 620 } 621 622 if (numElements != shapedType.getNumElements()) { 623 return op.emitOpError("result number of elements (") 624 << numElements << ") does not match value number of elements (" 625 << shapedType.getNumElements() << ")"; 626 } 627 return success(); 628 } 629 if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) { 630 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType); 631 if (!arrayType) 632 return op.emitOpError( 633 "must have spirv.array result type for array value"); 634 Type elemType = arrayType.getElementType(); 635 for (Attribute element : arrayAttr.getValue()) { 636 // Verify array elements recursively. 637 if (failed(verifyConstantType(op, element, elemType))) 638 return failure(); 639 } 640 return success(); 641 } 642 return op.emitOpError("cannot have attribute: ") << value; 643 } 644 645 LogicalResult spirv::ConstantOp::verify() { 646 // ODS already generates checks to make sure the result type is valid. We just 647 // need to additionally check that the value's attribute type is consistent 648 // with the result type. 649 return verifyConstantType(*this, getValueAttr(), getType()); 650 } 651 652 bool spirv::ConstantOp::isBuildableWith(Type type) { 653 // Must be valid SPIR-V type first. 654 if (!llvm::isa<spirv::SPIRVType>(type)) 655 return false; 656 657 if (isa<SPIRVDialect>(type.getDialect())) { 658 // TODO: support constant struct 659 return llvm::isa<spirv::ArrayType>(type); 660 } 661 662 return true; 663 } 664 665 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, 666 OpBuilder &builder) { 667 if (auto intType = llvm::dyn_cast<IntegerType>(type)) { 668 unsigned width = intType.getWidth(); 669 if (width == 1) 670 return builder.create<spirv::ConstantOp>(loc, type, 671 builder.getBoolAttr(false)); 672 return builder.create<spirv::ConstantOp>( 673 loc, type, builder.getIntegerAttr(type, APInt(width, 0))); 674 } 675 if (auto floatType = llvm::dyn_cast<FloatType>(type)) { 676 return builder.create<spirv::ConstantOp>( 677 loc, type, builder.getFloatAttr(floatType, 0.0)); 678 } 679 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) { 680 Type elemType = vectorType.getElementType(); 681 if (llvm::isa<IntegerType>(elemType)) { 682 return builder.create<spirv::ConstantOp>( 683 loc, type, 684 DenseElementsAttr::get(vectorType, 685 IntegerAttr::get(elemType, 0).getValue())); 686 } 687 if (llvm::isa<FloatType>(elemType)) { 688 return builder.create<spirv::ConstantOp>( 689 loc, type, 690 DenseFPElementsAttr::get(vectorType, 691 FloatAttr::get(elemType, 0.0).getValue())); 692 } 693 } 694 695 llvm_unreachable("unimplemented types for ConstantOp::getZero()"); 696 } 697 698 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, 699 OpBuilder &builder) { 700 if (auto intType = llvm::dyn_cast<IntegerType>(type)) { 701 unsigned width = intType.getWidth(); 702 if (width == 1) 703 return builder.create<spirv::ConstantOp>(loc, type, 704 builder.getBoolAttr(true)); 705 return builder.create<spirv::ConstantOp>( 706 loc, type, builder.getIntegerAttr(type, APInt(width, 1))); 707 } 708 if (auto floatType = llvm::dyn_cast<FloatType>(type)) { 709 return builder.create<spirv::ConstantOp>( 710 loc, type, builder.getFloatAttr(floatType, 1.0)); 711 } 712 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) { 713 Type elemType = vectorType.getElementType(); 714 if (llvm::isa<IntegerType>(elemType)) { 715 return builder.create<spirv::ConstantOp>( 716 loc, type, 717 DenseElementsAttr::get(vectorType, 718 IntegerAttr::get(elemType, 1).getValue())); 719 } 720 if (llvm::isa<FloatType>(elemType)) { 721 return builder.create<spirv::ConstantOp>( 722 loc, type, 723 DenseFPElementsAttr::get(vectorType, 724 FloatAttr::get(elemType, 1.0).getValue())); 725 } 726 } 727 728 llvm_unreachable("unimplemented types for ConstantOp::getOne()"); 729 } 730 731 void mlir::spirv::ConstantOp::getAsmResultNames( 732 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { 733 Type type = getType(); 734 735 SmallString<32> specialNameBuffer; 736 llvm::raw_svector_ostream specialName(specialNameBuffer); 737 specialName << "cst"; 738 739 IntegerType intTy = llvm::dyn_cast<IntegerType>(type); 740 741 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { 742 if (intTy && intTy.getWidth() == 1) { 743 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 744 } 745 746 if (intTy.isSignless()) { 747 specialName << intCst.getInt(); 748 } else if (intTy.isUnsigned()) { 749 specialName << intCst.getUInt(); 750 } else { 751 specialName << intCst.getSInt(); 752 } 753 } 754 755 if (intTy || llvm::isa<FloatType>(type)) { 756 specialName << '_' << type; 757 } 758 759 if (auto vecType = llvm::dyn_cast<VectorType>(type)) { 760 specialName << "_vec_"; 761 specialName << vecType.getDimSize(0); 762 763 Type elementType = vecType.getElementType(); 764 765 if (llvm::isa<IntegerType>(elementType) || 766 llvm::isa<FloatType>(elementType)) { 767 specialName << "x" << elementType; 768 } 769 } 770 771 setNameFn(getResult(), specialName.str()); 772 } 773 774 void mlir::spirv::AddressOfOp::getAsmResultNames( 775 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { 776 SmallString<32> specialNameBuffer; 777 llvm::raw_svector_ostream specialName(specialNameBuffer); 778 specialName << getVariable() << "_addr"; 779 setNameFn(getResult(), specialName.str()); 780 } 781 782 //===----------------------------------------------------------------------===// 783 // spirv.ControlBarrierOp 784 //===----------------------------------------------------------------------===// 785 786 LogicalResult spirv::ControlBarrierOp::verify() { 787 return verifyMemorySemantics(getOperation(), getMemorySemantics()); 788 } 789 790 //===----------------------------------------------------------------------===// 791 // spirv.EntryPoint 792 //===----------------------------------------------------------------------===// 793 794 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, 795 spirv::ExecutionModel executionModel, 796 spirv::FuncOp function, 797 ArrayRef<Attribute> interfaceVars) { 798 build(builder, state, 799 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), 800 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); 801 } 802 803 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, 804 OperationState &result) { 805 spirv::ExecutionModel execModel; 806 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers; 807 SmallVector<Type, 0> idTypes; 808 SmallVector<Attribute, 4> interfaceVars; 809 810 FlatSymbolRefAttr fn; 811 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) || 812 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { 813 return failure(); 814 } 815 816 if (!parser.parseOptionalComma()) { 817 // Parse the interface variables 818 if (parser.parseCommaSeparatedList([&]() -> ParseResult { 819 // The name of the interface variable attribute isnt important 820 FlatSymbolRefAttr var; 821 NamedAttrList attrs; 822 if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) 823 return failure(); 824 interfaceVars.push_back(var); 825 return success(); 826 })) 827 return failure(); 828 } 829 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name), 830 parser.getBuilder().getArrayAttr(interfaceVars)); 831 return success(); 832 } 833 834 void spirv::EntryPointOp::print(OpAsmPrinter &printer) { 835 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; 836 printer.printSymbolName(getFn()); 837 auto interfaceVars = getInterface().getValue(); 838 if (!interfaceVars.empty()) { 839 printer << ", "; 840 llvm::interleaveComma(interfaceVars, printer); 841 } 842 } 843 844 LogicalResult spirv::EntryPointOp::verify() { 845 // Checks for fn and interface symbol reference are done in spirv::ModuleOp 846 // verification. 847 return success(); 848 } 849 850 //===----------------------------------------------------------------------===// 851 // spirv.ExecutionMode 852 //===----------------------------------------------------------------------===// 853 854 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, 855 spirv::FuncOp function, 856 spirv::ExecutionMode executionMode, 857 ArrayRef<int32_t> params) { 858 build(builder, state, SymbolRefAttr::get(function), 859 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), 860 builder.getI32ArrayAttr(params)); 861 } 862 863 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, 864 OperationState &result) { 865 spirv::ExecutionMode execMode; 866 Attribute fn; 867 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) || 868 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) { 869 return failure(); 870 } 871 872 SmallVector<int32_t, 4> values; 873 Type i32Type = parser.getBuilder().getIntegerType(32); 874 while (!parser.parseOptionalComma()) { 875 NamedAttrList attr; 876 Attribute value; 877 if (parser.parseAttribute(value, i32Type, "value", attr)) { 878 return failure(); 879 } 880 values.push_back(llvm::cast<IntegerAttr>(value).getInt()); 881 } 882 StringRef valuesAttrName = 883 spirv::ExecutionModeOp::getValuesAttrName(result.name); 884 result.addAttribute(valuesAttrName, 885 parser.getBuilder().getI32ArrayAttr(values)); 886 return success(); 887 } 888 889 void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { 890 printer << " "; 891 printer.printSymbolName(getFn()); 892 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; 893 auto values = this->getValues(); 894 if (values.empty()) 895 return; 896 printer << ", "; 897 llvm::interleaveComma(values, printer, [&](Attribute a) { 898 printer << llvm::cast<IntegerAttr>(a).getInt(); 899 }); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // spirv.func 904 //===----------------------------------------------------------------------===// 905 906 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { 907 SmallVector<OpAsmParser::Argument> entryArgs; 908 SmallVector<DictionaryAttr> resultAttrs; 909 SmallVector<Type> resultTypes; 910 auto &builder = parser.getBuilder(); 911 912 // Parse the name as a symbol. 913 StringAttr nameAttr; 914 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 915 result.attributes)) 916 return failure(); 917 918 // Parse the function signature. 919 bool isVariadic = false; 920 if (function_interface_impl::parseFunctionSignature( 921 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, 922 resultAttrs)) 923 return failure(); 924 925 SmallVector<Type> argTypes; 926 for (auto &arg : entryArgs) 927 argTypes.push_back(arg.type); 928 auto fnType = builder.getFunctionType(argTypes, resultTypes); 929 result.addAttribute(getFunctionTypeAttrName(result.name), 930 TypeAttr::get(fnType)); 931 932 // Parse the optional function control keyword. 933 spirv::FunctionControl fnControl; 934 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result)) 935 return failure(); 936 937 // If additional attributes are present, parse them. 938 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 939 return failure(); 940 941 // Add the attributes to the function arguments. 942 assert(resultAttrs.size() == resultTypes.size()); 943 function_interface_impl::addArgAndResultAttrs( 944 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), 945 getResAttrsAttrName(result.name)); 946 947 // Parse the optional function body. 948 auto *body = result.addRegion(); 949 OptionalParseResult parseResult = 950 parser.parseOptionalRegion(*body, entryArgs); 951 return failure(parseResult.has_value() && failed(*parseResult)); 952 } 953 954 void spirv::FuncOp::print(OpAsmPrinter &printer) { 955 // Print function name, signature, and control. 956 printer << " "; 957 printer.printSymbolName(getSymName()); 958 auto fnType = getFunctionType(); 959 function_interface_impl::printFunctionSignature( 960 printer, *this, fnType.getInputs(), 961 /*isVariadic=*/false, fnType.getResults()); 962 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) 963 << "\""; 964 function_interface_impl::printFunctionAttributes( 965 printer, *this, 966 {spirv::attributeName<spirv::FunctionControl>(), 967 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), 968 getFunctionControlAttrName()}); 969 970 // Print the body if this is not an external function. 971 Region &body = this->getBody(); 972 if (!body.empty()) { 973 printer << ' '; 974 printer.printRegion(body, /*printEntryBlockArgs=*/false, 975 /*printBlockTerminators=*/true); 976 } 977 } 978 979 LogicalResult spirv::FuncOp::verifyType() { 980 FunctionType fnType = getFunctionType(); 981 if (fnType.getNumResults() > 1) 982 return emitOpError("cannot have more than one result"); 983 984 auto hasDecorationAttr = [&](spirv::Decoration decoration, 985 unsigned argIndex) { 986 auto func = llvm::cast<FunctionOpInterface>(getOperation()); 987 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) { 988 if (argAttr.getName() != spirv::DecorationAttr::name) 989 continue; 990 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue())) 991 return decAttr.getValue() == decoration; 992 } 993 return false; 994 }; 995 996 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) { 997 Type param = fnType.getInputs()[i]; 998 auto inputPtrType = dyn_cast<spirv::PointerType>(param); 999 if (!inputPtrType) 1000 continue; 1001 1002 auto pointeePtrType = 1003 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType()); 1004 if (pointeePtrType) { 1005 // SPIR-V spec, from SPV_KHR_physical_storage_buffer: 1006 // > If an OpFunctionParameter is a pointer (or contains a pointer) 1007 // > and the type it points to is a pointer in the PhysicalStorageBuffer 1008 // > storage class, the function parameter must be decorated with exactly 1009 // > one of AliasedPointer or RestrictPointer. 1010 if (pointeePtrType.getStorageClass() != 1011 spirv::StorageClass::PhysicalStorageBuffer) 1012 continue; 1013 1014 bool hasAliasedPtr = 1015 hasDecorationAttr(spirv::Decoration::AliasedPointer, i); 1016 bool hasRestrictPtr = 1017 hasDecorationAttr(spirv::Decoration::RestrictPointer, i); 1018 if (!hasAliasedPtr && !hasRestrictPtr) 1019 return emitOpError() 1020 << "with a pointer points to a physical buffer pointer must " 1021 "be decorated either 'AliasedPointer' or 'RestrictPointer'"; 1022 continue; 1023 } 1024 // SPIR-V spec, from SPV_KHR_physical_storage_buffer: 1025 // > If an OpFunctionParameter is a pointer (or contains a pointer) in 1026 // > the PhysicalStorageBuffer storage class, the function parameter must 1027 // > be decorated with exactly one of Aliased or Restrict. 1028 if (auto pointeeArrayType = 1029 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) { 1030 pointeePtrType = 1031 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType()); 1032 } else { 1033 pointeePtrType = inputPtrType; 1034 } 1035 1036 if (!pointeePtrType || pointeePtrType.getStorageClass() != 1037 spirv::StorageClass::PhysicalStorageBuffer) 1038 continue; 1039 1040 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i); 1041 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i); 1042 if (!hasAliased && !hasRestrict) 1043 return emitOpError() << "with physical buffer pointer must be decorated " 1044 "either 'Aliased' or 'Restrict'"; 1045 } 1046 1047 return success(); 1048 } 1049 1050 LogicalResult spirv::FuncOp::verifyBody() { 1051 FunctionType fnType = getFunctionType(); 1052 1053 auto walkResult = walk([fnType](Operation *op) -> WalkResult { 1054 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) { 1055 if (fnType.getNumResults() != 0) 1056 return retOp.emitOpError("cannot be used in functions returning value"); 1057 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) { 1058 if (fnType.getNumResults() != 1) 1059 return retOp.emitOpError( 1060 "returns 1 value but enclosing function requires ") 1061 << fnType.getNumResults() << " results"; 1062 1063 auto retOperandType = retOp.getValue().getType(); 1064 auto fnResultType = fnType.getResult(0); 1065 if (retOperandType != fnResultType) 1066 return retOp.emitOpError(" return value's type (") 1067 << retOperandType << ") mismatch with function's result type (" 1068 << fnResultType << ")"; 1069 } 1070 return WalkResult::advance(); 1071 }); 1072 1073 // TODO: verify other bits like linkage type. 1074 1075 return failure(walkResult.wasInterrupted()); 1076 } 1077 1078 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, 1079 StringRef name, FunctionType type, 1080 spirv::FunctionControl control, 1081 ArrayRef<NamedAttribute> attrs) { 1082 state.addAttribute(SymbolTable::getSymbolAttrName(), 1083 builder.getStringAttr(name)); 1084 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); 1085 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(), 1086 builder.getAttr<spirv::FunctionControlAttr>(control)); 1087 state.attributes.append(attrs.begin(), attrs.end()); 1088 state.addRegion(); 1089 } 1090 1091 //===----------------------------------------------------------------------===// 1092 // spirv.GLFClampOp 1093 //===----------------------------------------------------------------------===// 1094 1095 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser, 1096 OperationState &result) { 1097 return parseOneResultSameOperandTypeOp(parser, result); 1098 } 1099 void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } 1100 1101 //===----------------------------------------------------------------------===// 1102 // spirv.GLUClampOp 1103 //===----------------------------------------------------------------------===// 1104 1105 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser, 1106 OperationState &result) { 1107 return parseOneResultSameOperandTypeOp(parser, result); 1108 } 1109 void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } 1110 1111 //===----------------------------------------------------------------------===// 1112 // spirv.GLSClampOp 1113 //===----------------------------------------------------------------------===// 1114 1115 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser, 1116 OperationState &result) { 1117 return parseOneResultSameOperandTypeOp(parser, result); 1118 } 1119 void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } 1120 1121 //===----------------------------------------------------------------------===// 1122 // spirv.GLFmaOp 1123 //===----------------------------------------------------------------------===// 1124 1125 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) { 1126 return parseOneResultSameOperandTypeOp(parser, result); 1127 } 1128 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } 1129 1130 //===----------------------------------------------------------------------===// 1131 // spirv.GlobalVariable 1132 //===----------------------------------------------------------------------===// 1133 1134 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, 1135 Type type, StringRef name, 1136 unsigned descriptorSet, unsigned binding) { 1137 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); 1138 state.addAttribute( 1139 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), 1140 builder.getI32IntegerAttr(descriptorSet)); 1141 state.addAttribute( 1142 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), 1143 builder.getI32IntegerAttr(binding)); 1144 } 1145 1146 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, 1147 Type type, StringRef name, 1148 spirv::BuiltIn builtin) { 1149 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); 1150 state.addAttribute( 1151 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), 1152 builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); 1153 } 1154 1155 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, 1156 OperationState &result) { 1157 // Parse variable name. 1158 StringAttr nameAttr; 1159 StringRef initializerAttrName = 1160 spirv::GlobalVariableOp::getInitializerAttrName(result.name); 1161 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1162 result.attributes)) { 1163 return failure(); 1164 } 1165 1166 // Parse optional initializer 1167 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) { 1168 FlatSymbolRefAttr initSymbol; 1169 if (parser.parseLParen() || 1170 parser.parseAttribute(initSymbol, Type(), initializerAttrName, 1171 result.attributes) || 1172 parser.parseRParen()) 1173 return failure(); 1174 } 1175 1176 if (parseVariableDecorations(parser, result)) { 1177 return failure(); 1178 } 1179 1180 Type type; 1181 StringRef typeAttrName = 1182 spirv::GlobalVariableOp::getTypeAttrName(result.name); 1183 auto loc = parser.getCurrentLocation(); 1184 if (parser.parseColonType(type)) { 1185 return failure(); 1186 } 1187 if (!llvm::isa<spirv::PointerType>(type)) { 1188 return parser.emitError(loc, "expected spirv.ptr type"); 1189 } 1190 result.addAttribute(typeAttrName, TypeAttr::get(type)); 1191 1192 return success(); 1193 } 1194 1195 void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { 1196 SmallVector<StringRef, 4> elidedAttrs{ 1197 spirv::attributeName<spirv::StorageClass>()}; 1198 1199 // Print variable name. 1200 printer << ' '; 1201 printer.printSymbolName(getSymName()); 1202 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 1203 1204 StringRef initializerAttrName = this->getInitializerAttrName(); 1205 // Print optional initializer 1206 if (auto initializer = this->getInitializer()) { 1207 printer << " " << initializerAttrName << '('; 1208 printer.printSymbolName(*initializer); 1209 printer << ')'; 1210 elidedAttrs.push_back(initializerAttrName); 1211 } 1212 1213 StringRef typeAttrName = this->getTypeAttrName(); 1214 elidedAttrs.push_back(typeAttrName); 1215 spirv::printVariableDecorations(*this, printer, elidedAttrs); 1216 printer << " : " << getType(); 1217 } 1218 1219 LogicalResult spirv::GlobalVariableOp::verify() { 1220 if (!llvm::isa<spirv::PointerType>(getType())) 1221 return emitOpError("result must be of a !spv.ptr type"); 1222 1223 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the 1224 // object. It cannot be Generic. It must be the same as the Storage Class 1225 // operand of the Result Type." 1226 // Also, Function storage class is reserved by spirv.Variable. 1227 auto storageClass = this->storageClass(); 1228 if (storageClass == spirv::StorageClass::Generic || 1229 storageClass == spirv::StorageClass::Function) { 1230 return emitOpError("storage class cannot be '") 1231 << stringifyStorageClass(storageClass) << "'"; 1232 } 1233 1234 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>( 1235 this->getInitializerAttrName())) { 1236 Operation *initOp = SymbolTable::lookupNearestSymbolFrom( 1237 (*this)->getParentOp(), init.getAttr()); 1238 // TODO: Currently only variable initialization with specialization 1239 // constants and other variables is supported. They could be normal 1240 // constants in the module scope as well. 1241 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, 1242 spirv::SpecConstantCompositeOp>(initOp)) { 1243 return emitOpError("initializer must be result of a " 1244 "spirv.SpecConstant or spirv.GlobalVariable or " 1245 "spirv.SpecConstantCompositeOp op"); 1246 } 1247 } 1248 1249 return success(); 1250 } 1251 1252 //===----------------------------------------------------------------------===// 1253 // spirv.INTEL.SubgroupBlockRead 1254 //===----------------------------------------------------------------------===// 1255 1256 LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { 1257 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) 1258 return failure(); 1259 1260 return success(); 1261 } 1262 1263 //===----------------------------------------------------------------------===// 1264 // spirv.INTEL.SubgroupBlockWrite 1265 //===----------------------------------------------------------------------===// 1266 1267 ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, 1268 OperationState &result) { 1269 // Parse the storage class specification 1270 spirv::StorageClass storageClass; 1271 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; 1272 auto loc = parser.getCurrentLocation(); 1273 Type elementType; 1274 if (parseEnumStrAttr(storageClass, parser) || 1275 parser.parseOperandList(operandInfo, 2) || parser.parseColon() || 1276 parser.parseType(elementType)) { 1277 return failure(); 1278 } 1279 1280 auto ptrType = spirv::PointerType::get(elementType, storageClass); 1281 if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType)) 1282 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); 1283 1284 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, 1285 result.operands)) { 1286 return failure(); 1287 } 1288 return success(); 1289 } 1290 1291 void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { 1292 printer << " " << getPtr() << ", " << getValue() << " : " 1293 << getValue().getType(); 1294 } 1295 1296 LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { 1297 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) 1298 return failure(); 1299 1300 return success(); 1301 } 1302 1303 //===----------------------------------------------------------------------===// 1304 // spirv.IAddCarryOp 1305 //===----------------------------------------------------------------------===// 1306 1307 LogicalResult spirv::IAddCarryOp::verify() { 1308 return ::verifyArithmeticExtendedBinaryOp(*this); 1309 } 1310 1311 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser, 1312 OperationState &result) { 1313 return ::parseArithmeticExtendedBinaryOp(parser, result); 1314 } 1315 1316 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) { 1317 ::printArithmeticExtendedBinaryOp(*this, printer); 1318 } 1319 1320 //===----------------------------------------------------------------------===// 1321 // spirv.ISubBorrowOp 1322 //===----------------------------------------------------------------------===// 1323 1324 LogicalResult spirv::ISubBorrowOp::verify() { 1325 return ::verifyArithmeticExtendedBinaryOp(*this); 1326 } 1327 1328 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, 1329 OperationState &result) { 1330 return ::parseArithmeticExtendedBinaryOp(parser, result); 1331 } 1332 1333 void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { 1334 ::printArithmeticExtendedBinaryOp(*this, printer); 1335 } 1336 1337 //===----------------------------------------------------------------------===// 1338 // spirv.SMulExtended 1339 //===----------------------------------------------------------------------===// 1340 1341 LogicalResult spirv::SMulExtendedOp::verify() { 1342 return ::verifyArithmeticExtendedBinaryOp(*this); 1343 } 1344 1345 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser, 1346 OperationState &result) { 1347 return ::parseArithmeticExtendedBinaryOp(parser, result); 1348 } 1349 1350 void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) { 1351 ::printArithmeticExtendedBinaryOp(*this, printer); 1352 } 1353 1354 //===----------------------------------------------------------------------===// 1355 // spirv.UMulExtended 1356 //===----------------------------------------------------------------------===// 1357 1358 LogicalResult spirv::UMulExtendedOp::verify() { 1359 return ::verifyArithmeticExtendedBinaryOp(*this); 1360 } 1361 1362 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, 1363 OperationState &result) { 1364 return ::parseArithmeticExtendedBinaryOp(parser, result); 1365 } 1366 1367 void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { 1368 ::printArithmeticExtendedBinaryOp(*this, printer); 1369 } 1370 1371 //===----------------------------------------------------------------------===// 1372 // spirv.MemoryBarrierOp 1373 //===----------------------------------------------------------------------===// 1374 1375 LogicalResult spirv::MemoryBarrierOp::verify() { 1376 return verifyMemorySemantics(getOperation(), getMemorySemantics()); 1377 } 1378 1379 //===----------------------------------------------------------------------===// 1380 // spirv.module 1381 //===----------------------------------------------------------------------===// 1382 1383 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, 1384 std::optional<StringRef> name) { 1385 OpBuilder::InsertionGuard guard(builder); 1386 builder.createBlock(state.addRegion()); 1387 if (name) { 1388 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), 1389 builder.getStringAttr(*name)); 1390 } 1391 } 1392 1393 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, 1394 spirv::AddressingModel addressingModel, 1395 spirv::MemoryModel memoryModel, 1396 std::optional<VerCapExtAttr> vceTriple, 1397 std::optional<StringRef> name) { 1398 state.addAttribute( 1399 "addressing_model", 1400 builder.getAttr<spirv::AddressingModelAttr>(addressingModel)); 1401 state.addAttribute("memory_model", 1402 builder.getAttr<spirv::MemoryModelAttr>(memoryModel)); 1403 OpBuilder::InsertionGuard guard(builder); 1404 builder.createBlock(state.addRegion()); 1405 if (vceTriple) 1406 state.addAttribute(getVCETripleAttrName(), *vceTriple); 1407 if (name) 1408 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1409 builder.getStringAttr(*name)); 1410 } 1411 1412 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, 1413 OperationState &result) { 1414 Region *body = result.addRegion(); 1415 1416 // If the name is present, parse it. 1417 StringAttr nameAttr; 1418 (void)parser.parseOptionalSymbolName( 1419 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes); 1420 1421 // Parse attributes 1422 spirv::AddressingModel addrModel; 1423 spirv::MemoryModel memoryModel; 1424 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser, 1425 result) || 1426 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser, 1427 result)) 1428 return failure(); 1429 1430 if (succeeded(parser.parseOptionalKeyword("requires"))) { 1431 spirv::VerCapExtAttr vceTriple; 1432 if (parser.parseAttribute(vceTriple, 1433 spirv::ModuleOp::getVCETripleAttrName(), 1434 result.attributes)) 1435 return failure(); 1436 } 1437 1438 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) || 1439 parser.parseRegion(*body, /*arguments=*/{})) 1440 return failure(); 1441 1442 // Make sure we have at least one block. 1443 if (body->empty()) 1444 body->push_back(new Block()); 1445 1446 return success(); 1447 } 1448 1449 void spirv::ModuleOp::print(OpAsmPrinter &printer) { 1450 if (std::optional<StringRef> name = getName()) { 1451 printer << ' '; 1452 printer.printSymbolName(*name); 1453 } 1454 1455 SmallVector<StringRef, 2> elidedAttrs; 1456 1457 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " " 1458 << spirv::stringifyMemoryModel(getMemoryModel()); 1459 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); 1460 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); 1461 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, 1462 mlir::SymbolTable::getSymbolAttrName()}); 1463 1464 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) { 1465 printer << " requires " << *triple; 1466 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); 1467 } 1468 1469 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); 1470 printer << ' '; 1471 printer.printRegion(getRegion()); 1472 } 1473 1474 LogicalResult spirv::ModuleOp::verifyRegions() { 1475 Dialect *dialect = (*this)->getDialect(); 1476 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> 1477 entryPoints; 1478 mlir::SymbolTable table(*this); 1479 1480 for (auto &op : *getBody()) { 1481 if (op.getDialect() != dialect) 1482 return op.emitError("'spirv.module' can only contain spirv.* ops"); 1483 1484 // For EntryPoint op, check that the function and execution model is not 1485 // duplicated in EntryPointOps. Also verify that the interface specified 1486 // comes from globalVariables here to make this check cheaper. 1487 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) { 1488 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn()); 1489 if (!funcOp) { 1490 return entryPointOp.emitError("function '") 1491 << entryPointOp.getFn() << "' not found in 'spirv.module'"; 1492 } 1493 if (auto interface = entryPointOp.getInterface()) { 1494 for (Attribute varRef : interface) { 1495 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef); 1496 if (!varSymRef) { 1497 return entryPointOp.emitError( 1498 "expected symbol reference for interface " 1499 "specification instead of '") 1500 << varRef; 1501 } 1502 auto variableOp = 1503 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()); 1504 if (!variableOp) { 1505 return entryPointOp.emitError("expected spirv.GlobalVariable " 1506 "symbol reference instead of'") 1507 << varSymRef << "'"; 1508 } 1509 } 1510 } 1511 1512 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>( 1513 funcOp, entryPointOp.getExecutionModel()); 1514 if (!entryPoints.try_emplace(key, entryPointOp).second) 1515 return entryPointOp.emitError("duplicate of a previous EntryPointOp"); 1516 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) { 1517 // If the function is external and does not have 'Import' 1518 // linkage_attributes(LinkageAttributes), throw an error. 'Import' 1519 // LinkageAttributes is used to import external functions. 1520 auto linkageAttr = funcOp.getLinkageAttributes(); 1521 auto hasImportLinkage = 1522 linkageAttr && (linkageAttr.value().getLinkageType().getValue() == 1523 spirv::LinkageType::Import); 1524 if (funcOp.isExternal() && !hasImportLinkage) 1525 return op.emitError( 1526 "'spirv.module' cannot contain external functions " 1527 "without 'Import' linkage_attributes (LinkageAttributes)"); 1528 1529 // TODO: move this check to spirv.func. 1530 for (auto &block : funcOp) 1531 for (auto &op : block) { 1532 if (op.getDialect() != dialect) 1533 return op.emitError( 1534 "functions in 'spirv.module' can only contain spirv.* ops"); 1535 } 1536 } 1537 } 1538 1539 return success(); 1540 } 1541 1542 //===----------------------------------------------------------------------===// 1543 // spirv.mlir.referenceof 1544 //===----------------------------------------------------------------------===// 1545 1546 LogicalResult spirv::ReferenceOfOp::verify() { 1547 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( 1548 (*this)->getParentOp(), getSpecConstAttr()); 1549 Type constType; 1550 1551 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym); 1552 if (specConstOp) 1553 constType = specConstOp.getDefaultValue().getType(); 1554 1555 auto specConstCompositeOp = 1556 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym); 1557 if (specConstCompositeOp) 1558 constType = specConstCompositeOp.getType(); 1559 1560 if (!specConstOp && !specConstCompositeOp) 1561 return emitOpError( 1562 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol"); 1563 1564 if (getReference().getType() != constType) 1565 return emitOpError("result type mismatch with the referenced " 1566 "specialization constant's type"); 1567 1568 return success(); 1569 } 1570 1571 //===----------------------------------------------------------------------===// 1572 // spirv.SpecConstant 1573 //===----------------------------------------------------------------------===// 1574 1575 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, 1576 OperationState &result) { 1577 StringAttr nameAttr; 1578 Attribute valueAttr; 1579 StringRef defaultValueAttrName = 1580 spirv::SpecConstantOp::getDefaultValueAttrName(result.name); 1581 1582 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1583 result.attributes)) 1584 return failure(); 1585 1586 // Parse optional spec_id. 1587 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) { 1588 IntegerAttr specIdAttr; 1589 if (parser.parseLParen() || 1590 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) || 1591 parser.parseRParen()) 1592 return failure(); 1593 } 1594 1595 if (parser.parseEqual() || 1596 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes)) 1597 return failure(); 1598 1599 return success(); 1600 } 1601 1602 void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { 1603 printer << ' '; 1604 printer.printSymbolName(getSymName()); 1605 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) 1606 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; 1607 printer << " = " << getDefaultValue(); 1608 } 1609 1610 LogicalResult spirv::SpecConstantOp::verify() { 1611 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) 1612 if (specID.getValue().isNegative()) 1613 return emitOpError("SpecId cannot be negative"); 1614 1615 auto value = getDefaultValue(); 1616 if (llvm::isa<IntegerAttr, FloatAttr>(value)) { 1617 // Make sure bitwidth is allowed. 1618 if (!llvm::isa<spirv::SPIRVType>(value.getType())) 1619 return emitOpError("default value bitwidth disallowed"); 1620 return success(); 1621 } 1622 return emitOpError( 1623 "default value can only be a bool, integer, or float scalar"); 1624 } 1625 1626 //===----------------------------------------------------------------------===// 1627 // spirv.VectorShuffle 1628 //===----------------------------------------------------------------------===// 1629 1630 LogicalResult spirv::VectorShuffleOp::verify() { 1631 VectorType resultType = llvm::cast<VectorType>(getType()); 1632 1633 size_t numResultElements = resultType.getNumElements(); 1634 if (numResultElements != getComponents().size()) 1635 return emitOpError("result type element count (") 1636 << numResultElements 1637 << ") mismatch with the number of component selectors (" 1638 << getComponents().size() << ")"; 1639 1640 size_t totalSrcElements = 1641 llvm::cast<VectorType>(getVector1().getType()).getNumElements() + 1642 llvm::cast<VectorType>(getVector2().getType()).getNumElements(); 1643 1644 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) { 1645 uint32_t index = selector.getZExtValue(); 1646 if (index >= totalSrcElements && 1647 index != std::numeric_limits<uint32_t>().max()) 1648 return emitOpError("component selector ") 1649 << index << " out of range: expected to be in [0, " 1650 << totalSrcElements << ") or 0xffffffff"; 1651 } 1652 return success(); 1653 } 1654 1655 //===----------------------------------------------------------------------===// 1656 // spirv.MatrixTimesScalar 1657 //===----------------------------------------------------------------------===// 1658 1659 LogicalResult spirv::MatrixTimesScalarOp::verify() { 1660 Type elementType = 1661 llvm::TypeSwitch<Type, Type>(getMatrix().getType()) 1662 .Case<spirv::CooperativeMatrixType, spirv::MatrixType>( 1663 [](auto matrixType) { return matrixType.getElementType(); }) 1664 .Default([](Type) { return nullptr; }); 1665 1666 assert(elementType && "Unhandled type"); 1667 1668 // Check that the scalar type is the same as the matrix element type. 1669 if (getScalar().getType() != elementType) 1670 return emitOpError("input matrix components' type and scaling value must " 1671 "have the same type"); 1672 1673 return success(); 1674 } 1675 1676 //===----------------------------------------------------------------------===// 1677 // spirv.Transpose 1678 //===----------------------------------------------------------------------===// 1679 1680 LogicalResult spirv::TransposeOp::verify() { 1681 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType()); 1682 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType()); 1683 1684 // Verify that the input and output matrices have correct shapes. 1685 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) 1686 return emitError("input matrix rows count must be equal to " 1687 "output matrix columns count"); 1688 1689 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) 1690 return emitError("input matrix columns count must be equal to " 1691 "output matrix rows count"); 1692 1693 // Verify that the input and output matrices have the same component type 1694 if (inputMatrix.getElementType() != resultMatrix.getElementType()) 1695 return emitError("input and output matrices must have the same " 1696 "component type"); 1697 1698 return success(); 1699 } 1700 1701 //===----------------------------------------------------------------------===// 1702 // spirv.MatrixTimesVector 1703 //===----------------------------------------------------------------------===// 1704 1705 LogicalResult spirv::MatrixTimesVectorOp::verify() { 1706 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType()); 1707 auto vectorType = llvm::cast<VectorType>(getVector().getType()); 1708 auto resultType = llvm::cast<VectorType>(getType()); 1709 1710 if (matrixType.getNumColumns() != vectorType.getNumElements()) 1711 return emitOpError("matrix columns (") 1712 << matrixType.getNumColumns() << ") must match vector operand size (" 1713 << vectorType.getNumElements() << ")"; 1714 1715 if (resultType.getNumElements() != matrixType.getNumRows()) 1716 return emitOpError("result size (") 1717 << resultType.getNumElements() << ") must match the matrix rows (" 1718 << matrixType.getNumRows() << ")"; 1719 1720 if (matrixType.getElementType() != resultType.getElementType()) 1721 return emitOpError("matrix and result element types must match"); 1722 1723 return success(); 1724 } 1725 1726 //===----------------------------------------------------------------------===// 1727 // spirv.VectorTimesMatrix 1728 //===----------------------------------------------------------------------===// 1729 1730 LogicalResult spirv::VectorTimesMatrixOp::verify() { 1731 auto vectorType = llvm::cast<VectorType>(getVector().getType()); 1732 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType()); 1733 auto resultType = llvm::cast<VectorType>(getType()); 1734 1735 if (matrixType.getNumRows() != vectorType.getNumElements()) 1736 return emitOpError("number of components in vector must equal the number " 1737 "of components in each column in matrix"); 1738 1739 if (resultType.getNumElements() != matrixType.getNumColumns()) 1740 return emitOpError("number of columns in matrix must equal the number of " 1741 "components in result"); 1742 1743 if (matrixType.getElementType() != resultType.getElementType()) 1744 return emitOpError("matrix must be a matrix with the same component type " 1745 "as the component type in result"); 1746 1747 return success(); 1748 } 1749 1750 //===----------------------------------------------------------------------===// 1751 // spirv.MatrixTimesMatrix 1752 //===----------------------------------------------------------------------===// 1753 1754 LogicalResult spirv::MatrixTimesMatrixOp::verify() { 1755 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType()); 1756 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType()); 1757 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType()); 1758 1759 // left matrix columns' count and right matrix rows' count must be equal 1760 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) 1761 return emitError("left matrix columns' count must be equal to " 1762 "the right matrix rows' count"); 1763 1764 // right and result matrices columns' count must be the same 1765 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) 1766 return emitError( 1767 "right and result matrices must have equal columns' count"); 1768 1769 // right and result matrices component type must be the same 1770 if (rightMatrix.getElementType() != resultMatrix.getElementType()) 1771 return emitError("right and result matrices' component type must" 1772 " be the same"); 1773 1774 // left and result matrices component type must be the same 1775 if (leftMatrix.getElementType() != resultMatrix.getElementType()) 1776 return emitError("left and result matrices' component type" 1777 " must be the same"); 1778 1779 // left and result matrices rows count must be the same 1780 if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) 1781 return emitError("left and result matrices must have equal rows' count"); 1782 1783 return success(); 1784 } 1785 1786 //===----------------------------------------------------------------------===// 1787 // spirv.SpecConstantComposite 1788 //===----------------------------------------------------------------------===// 1789 1790 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, 1791 OperationState &result) { 1792 1793 StringAttr compositeName; 1794 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), 1795 result.attributes)) 1796 return failure(); 1797 1798 if (parser.parseLParen()) 1799 return failure(); 1800 1801 SmallVector<Attribute, 4> constituents; 1802 1803 do { 1804 // The name of the constituent attribute isn't important 1805 const char *attrName = "spec_const"; 1806 FlatSymbolRefAttr specConstRef; 1807 NamedAttrList attrs; 1808 1809 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) 1810 return failure(); 1811 1812 constituents.push_back(specConstRef); 1813 } while (!parser.parseOptionalComma()); 1814 1815 if (parser.parseRParen()) 1816 return failure(); 1817 1818 StringAttr compositeSpecConstituentsName = 1819 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name); 1820 result.addAttribute(compositeSpecConstituentsName, 1821 parser.getBuilder().getArrayAttr(constituents)); 1822 1823 Type type; 1824 if (parser.parseColonType(type)) 1825 return failure(); 1826 1827 StringAttr typeAttrName = 1828 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name); 1829 result.addAttribute(typeAttrName, TypeAttr::get(type)); 1830 1831 return success(); 1832 } 1833 1834 void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { 1835 printer << " "; 1836 printer.printSymbolName(getSymName()); 1837 printer << " ("; 1838 auto constituents = this->getConstituents().getValue(); 1839 1840 if (!constituents.empty()) 1841 llvm::interleaveComma(constituents, printer); 1842 1843 printer << ") : " << getType(); 1844 } 1845 1846 LogicalResult spirv::SpecConstantCompositeOp::verify() { 1847 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType()); 1848 auto constituents = this->getConstituents().getValue(); 1849 1850 if (!cType) 1851 return emitError("result type must be a composite type, but provided ") 1852 << getType(); 1853 1854 if (llvm::isa<spirv::CooperativeMatrixType>(cType)) 1855 return emitError("unsupported composite type ") << cType; 1856 if (constituents.size() != cType.getNumElements()) 1857 return emitError("has incorrect number of operands: expected ") 1858 << cType.getNumElements() << ", but provided " 1859 << constituents.size(); 1860 1861 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 1862 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]); 1863 1864 auto constituentSpecConstOp = 1865 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom( 1866 (*this)->getParentOp(), constituent.getAttr())); 1867 1868 if (constituentSpecConstOp.getDefaultValue().getType() != 1869 cType.getElementType(index)) 1870 return emitError("has incorrect types of operands: expected ") 1871 << cType.getElementType(index) << ", but provided " 1872 << constituentSpecConstOp.getDefaultValue().getType(); 1873 } 1874 1875 return success(); 1876 } 1877 1878 //===----------------------------------------------------------------------===// 1879 // spirv.SpecConstantOperation 1880 //===----------------------------------------------------------------------===// 1881 1882 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, 1883 OperationState &result) { 1884 Region *body = result.addRegion(); 1885 1886 if (parser.parseKeyword("wraps")) 1887 return failure(); 1888 1889 body->push_back(new Block); 1890 Block &block = body->back(); 1891 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 1892 1893 if (!wrappedOp) 1894 return failure(); 1895 1896 OpBuilder builder(parser.getContext()); 1897 builder.setInsertionPointToEnd(&block); 1898 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0)); 1899 result.location = wrappedOp->getLoc(); 1900 1901 result.addTypes(wrappedOp->getResult(0).getType()); 1902 1903 if (parser.parseOptionalAttrDict(result.attributes)) 1904 return failure(); 1905 1906 return success(); 1907 } 1908 1909 void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { 1910 printer << " wraps "; 1911 printer.printGenericOp(&getBody().front().front()); 1912 } 1913 1914 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { 1915 Block &block = getRegion().getBlocks().front(); 1916 1917 if (block.getOperations().size() != 2) 1918 return emitOpError("expected exactly 2 nested ops"); 1919 1920 Operation &enclosedOp = block.getOperations().front(); 1921 1922 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>()) 1923 return emitOpError("invalid enclosed op"); 1924 1925 for (auto operand : enclosedOp.getOperands()) 1926 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp, 1927 spirv::SpecConstantOperationOp>(operand.getDefiningOp())) 1928 return emitOpError( 1929 "invalid operand, must be defined by a constant operation"); 1930 1931 return success(); 1932 } 1933 1934 //===----------------------------------------------------------------------===// 1935 // spirv.GL.FrexpStruct 1936 //===----------------------------------------------------------------------===// 1937 1938 LogicalResult spirv::GLFrexpStructOp::verify() { 1939 spirv::StructType structTy = 1940 llvm::dyn_cast<spirv::StructType>(getResult().getType()); 1941 1942 if (structTy.getNumElements() != 2) 1943 return emitError("result type must be a struct type with two memebers"); 1944 1945 Type significandTy = structTy.getElementType(0); 1946 Type exponentTy = structTy.getElementType(1); 1947 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy); 1948 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy); 1949 1950 Type operandTy = getOperand().getType(); 1951 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy); 1952 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy); 1953 1954 if (significandTy != operandTy) 1955 return emitError("member zero of the resulting struct type must be the " 1956 "same type as the operand"); 1957 1958 if (exponentVecTy) { 1959 IntegerType componentIntTy = 1960 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType()); 1961 if (!componentIntTy || componentIntTy.getWidth() != 32) 1962 return emitError("member one of the resulting struct type must" 1963 "be a scalar or vector of 32 bit integer type"); 1964 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { 1965 return emitError("member one of the resulting struct type " 1966 "must be a scalar or vector of 32 bit integer type"); 1967 } 1968 1969 // Check that the two member types have the same number of components 1970 if (operandVecTy && exponentVecTy && 1971 (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) 1972 return success(); 1973 1974 if (operandFTy && exponentIntTy) 1975 return success(); 1976 1977 return emitError("member one of the resulting struct type must have the same " 1978 "number of components as the operand type"); 1979 } 1980 1981 //===----------------------------------------------------------------------===// 1982 // spirv.GL.Ldexp 1983 //===----------------------------------------------------------------------===// 1984 1985 LogicalResult spirv::GLLdexpOp::verify() { 1986 Type significandType = getX().getType(); 1987 Type exponentType = getExp().getType(); 1988 1989 if (llvm::isa<FloatType>(significandType) != 1990 llvm::isa<IntegerType>(exponentType)) 1991 return emitOpError("operands must both be scalars or vectors"); 1992 1993 auto getNumElements = [](Type type) -> unsigned { 1994 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) 1995 return vectorType.getNumElements(); 1996 return 1; 1997 }; 1998 1999 if (getNumElements(significandType) != getNumElements(exponentType)) 2000 return emitOpError("operands must have the same number of elements"); 2001 2002 return success(); 2003 } 2004 2005 //===----------------------------------------------------------------------===// 2006 // spirv.ImageDrefGather 2007 //===----------------------------------------------------------------------===// 2008 2009 LogicalResult spirv::ImageDrefGatherOp::verify() { 2010 VectorType resultType = llvm::cast<VectorType>(getResult().getType()); 2011 auto sampledImageType = 2012 llvm::cast<spirv::SampledImageType>(getSampledimage().getType()); 2013 auto imageType = 2014 llvm::cast<spirv::ImageType>(sampledImageType.getImageType()); 2015 2016 if (resultType.getNumElements() != 4) 2017 return emitOpError("result type must be a vector of four components"); 2018 2019 Type elementType = resultType.getElementType(); 2020 Type sampledElementType = imageType.getElementType(); 2021 if (!llvm::isa<NoneType>(sampledElementType) && 2022 elementType != sampledElementType) 2023 return emitOpError( 2024 "the component type of result must be the same as sampled type of the " 2025 "underlying image type"); 2026 2027 spirv::Dim imageDim = imageType.getDim(); 2028 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); 2029 2030 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && 2031 imageDim != spirv::Dim::Rect) 2032 return emitOpError( 2033 "the Dim operand of the underlying image type must be 2D, Cube, or " 2034 "Rect"); 2035 2036 if (imageMS != spirv::ImageSamplingInfo::SingleSampled) 2037 return emitOpError("the MS operand of the underlying image type must be 0"); 2038 2039 spirv::ImageOperandsAttr attr = getImageoperandsAttr(); 2040 auto operandArguments = getOperandArguments(); 2041 2042 return verifyImageOperands(*this, attr, operandArguments); 2043 } 2044 2045 //===----------------------------------------------------------------------===// 2046 // spirv.ShiftLeftLogicalOp 2047 //===----------------------------------------------------------------------===// 2048 2049 LogicalResult spirv::ShiftLeftLogicalOp::verify() { 2050 return verifyShiftOp(*this); 2051 } 2052 2053 //===----------------------------------------------------------------------===// 2054 // spirv.ShiftRightArithmeticOp 2055 //===----------------------------------------------------------------------===// 2056 2057 LogicalResult spirv::ShiftRightArithmeticOp::verify() { 2058 return verifyShiftOp(*this); 2059 } 2060 2061 //===----------------------------------------------------------------------===// 2062 // spirv.ShiftRightLogicalOp 2063 //===----------------------------------------------------------------------===// 2064 2065 LogicalResult spirv::ShiftRightLogicalOp::verify() { 2066 return verifyShiftOp(*this); 2067 } 2068 2069 //===----------------------------------------------------------------------===// 2070 // spirv.ImageQuerySize 2071 //===----------------------------------------------------------------------===// 2072 2073 LogicalResult spirv::ImageQuerySizeOp::verify() { 2074 spirv::ImageType imageType = 2075 llvm::cast<spirv::ImageType>(getImage().getType()); 2076 Type resultType = getResult().getType(); 2077 2078 spirv::Dim dim = imageType.getDim(); 2079 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); 2080 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); 2081 switch (dim) { 2082 case spirv::Dim::Dim1D: 2083 case spirv::Dim::Dim2D: 2084 case spirv::Dim::Dim3D: 2085 case spirv::Dim::Cube: 2086 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && 2087 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && 2088 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) 2089 return emitError( 2090 "if Dim is 1D, 2D, 3D, or Cube, " 2091 "it must also have either an MS of 1 or a Sampled of 0 or 2"); 2092 break; 2093 case spirv::Dim::Buffer: 2094 case spirv::Dim::Rect: 2095 break; 2096 default: 2097 return emitError("the Dim operand of the image type must " 2098 "be 1D, 2D, 3D, Buffer, Cube, or Rect"); 2099 } 2100 2101 unsigned componentNumber = 0; 2102 switch (dim) { 2103 case spirv::Dim::Dim1D: 2104 case spirv::Dim::Buffer: 2105 componentNumber = 1; 2106 break; 2107 case spirv::Dim::Dim2D: 2108 case spirv::Dim::Cube: 2109 case spirv::Dim::Rect: 2110 componentNumber = 2; 2111 break; 2112 case spirv::Dim::Dim3D: 2113 componentNumber = 3; 2114 break; 2115 default: 2116 break; 2117 } 2118 2119 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) 2120 componentNumber += 1; 2121 2122 unsigned resultComponentNumber = 1; 2123 if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType)) 2124 resultComponentNumber = resultVectorType.getNumElements(); 2125 2126 if (componentNumber != resultComponentNumber) 2127 return emitError("expected the result to have ") 2128 << componentNumber << " component(s), but found " 2129 << resultComponentNumber << " component(s)"; 2130 2131 return success(); 2132 } 2133 2134 //===----------------------------------------------------------------------===// 2135 // spirv.VectorTimesScalarOp 2136 //===----------------------------------------------------------------------===// 2137 2138 LogicalResult spirv::VectorTimesScalarOp::verify() { 2139 if (getVector().getType() != getType()) 2140 return emitOpError("vector operand and result type mismatch"); 2141 auto scalarType = llvm::cast<VectorType>(getType()).getElementType(); 2142 if (getScalar().getType() != scalarType) 2143 return emitOpError("scalar operand and result element type match"); 2144 return success(); 2145 } 2146