1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/Dialect/Vector/IR/VectorOps.h" 23 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 25 #include "mlir/IR/BuiltinTypes.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/IR/PatternMatch.h" 28 #include "mlir/Pass/Pass.h" 29 #include "mlir/Support/LLVM.h" 30 #include "mlir/Transforms/DialectConversion.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "mlir/Transforms/OneToNTypeConversion.h" 33 #include "llvm/ADT/STLExtras.h" 34 #include "llvm/ADT/SmallVector.h" 35 #include "llvm/ADT/StringExtras.h" 36 #include "llvm/Support/Debug.h" 37 #include "llvm/Support/LogicalResult.h" 38 #include "llvm/Support/MathExtras.h" 39 40 #include <functional> 41 #include <optional> 42 43 #define DEBUG_TYPE "mlir-spirv-conversion" 44 45 using namespace mlir; 46 47 namespace { 48 49 //===----------------------------------------------------------------------===// 50 // Utility functions 
51 //===----------------------------------------------------------------------===// 52 53 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) { 54 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n"); 55 if (vecType.isScalable()) { 56 LLVM_DEBUG(llvm::dbgs() 57 << "--scalable vectors are not supported -> BAIL\n"); 58 return std::nullopt; 59 } 60 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape()); 61 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>( 62 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back())); 63 if (!targetShape) { 64 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n"); 65 return std::nullopt; 66 } 67 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape); 68 if (!maybeShapeRatio) { 69 LLVM_DEBUG(llvm::dbgs() 70 << "--could not compute integral shape ratio -> BAIL\n"); 71 return std::nullopt; 72 } 73 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { 74 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n"); 75 return std::nullopt; 76 } 77 LLVM_DEBUG(llvm::dbgs() 78 << "--found an integral shape ratio to unroll to -> SUCCESS\n"); 79 return targetShape; 80 } 81 82 /// Checks that `candidates` extension requirements are possible to be satisfied 83 /// with the given `targetEnv`. 84 /// 85 /// `candidates` is a vector of vector for extension requirements following 86 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 87 /// convention. 
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
  // Each entry in `candidates` is an OR-clause; all clauses must be satisfied.
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    // None of the alternatives in this OR-clause is available: report which
    // extensions would have satisfied it.
    LLVM_DEBUG({
      SmallVector<StringRef> extStrings;
      for (spirv::Extension ext : ors)
        extStrings.push_back(spirv::stringifyExtension(ext));

      llvm::dbgs() << label << " illegal: requires at least one extension in ["
                   << llvm::join(extStrings, ", ")
                   << "] but none allowed in target environment\n";
    });
    return failure();
  }
  return success();
}

/// Checks that `candidates` capability requirements are possible to be
/// satisfied with the given `targetEnv`.
///
/// `candidates` is a vector of vector for capability requirements following
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
  // Same AND-of-ORs structure as checkExtensionRequirements above, but for
  // capabilities.
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    LLVM_DEBUG({
      SmallVector<StringRef> capStrings;
      for (spirv::Capability cap : ors)
        capStrings.push_back(spirv::stringifyCapability(cap));

      llvm::dbgs() << label << " illegal: requires at least one capability in ["
                   << llvm::join(capStrings, ", ")
                   << "] but none allowed in target environment\n";
    });
    return failure();
  }
  return success();
}

/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) { 141 switch (storageClass) { 142 case spirv::StorageClass::PhysicalStorageBuffer: 143 case spirv::StorageClass::PushConstant: 144 case spirv::StorageClass::StorageBuffer: 145 case spirv::StorageClass::Uniform: 146 return true; 147 default: 148 return false; 149 } 150 } 151 152 /// Wraps the given `elementType` in a struct and gets the pointer to the 153 /// struct. This is used to satisfy Vulkan interface requirements. 154 static spirv::PointerType 155 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { 156 auto structType = needsExplicitLayout(storageClass) 157 ? spirv::StructType::get(elementType, /*offsetInfo=*/0) 158 : spirv::StructType::get(elementType); 159 return spirv::PointerType::get(structType, storageClass); 160 } 161 162 //===----------------------------------------------------------------------===// 163 // Type Conversion 164 //===----------------------------------------------------------------------===// 165 166 static spirv::ScalarType getIndexType(MLIRContext *ctx, 167 const SPIRVConversionOptions &options) { 168 return cast<spirv::ScalarType>( 169 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32)); 170 } 171 172 // TODO: This is a utility function that should probably be exposed by the 173 // SPIR-V dialect. Keeping it local till the use case arises. 174 static std::optional<int64_t> 175 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { 176 if (isa<spirv::ScalarType>(type)) { 177 auto bitWidth = type.getIntOrFloatBitWidth(); 178 // According to the SPIR-V spec: 179 // "There is no physical size or bit pattern defined for values with boolean 180 // type. If they are stored (in conjunction with OpVariable), they can only 181 // be used with logical addressing operations, not physical, and only with 182 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 183 // Private, Function, Input, and Output." 
184 if (bitWidth == 1) 185 return std::nullopt; 186 return bitWidth / 8; 187 } 188 189 if (auto complexType = dyn_cast<ComplexType>(type)) { 190 auto elementSize = getTypeNumBytes(options, complexType.getElementType()); 191 if (!elementSize) 192 return std::nullopt; 193 return 2 * *elementSize; 194 } 195 196 if (auto vecType = dyn_cast<VectorType>(type)) { 197 auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 198 if (!elementSize) 199 return std::nullopt; 200 return vecType.getNumElements() * *elementSize; 201 } 202 203 if (auto memRefType = dyn_cast<MemRefType>(type)) { 204 // TODO: Layout should also be controlled by the ABI attributes. For now 205 // using the layout from MemRef. 206 int64_t offset; 207 SmallVector<int64_t, 4> strides; 208 if (!memRefType.hasStaticShape() || 209 failed(memRefType.getStridesAndOffset(strides, offset))) 210 return std::nullopt; 211 212 // To get the size of the memref object in memory, the total size is the 213 // max(stride * dimension-size) computed for all dimensions times the size 214 // of the element. 
215 auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 216 if (!elementSize) 217 return std::nullopt; 218 219 if (memRefType.getRank() == 0) 220 return elementSize; 221 222 auto dims = memRefType.getShape(); 223 if (llvm::is_contained(dims, ShapedType::kDynamic) || 224 ShapedType::isDynamic(offset) || 225 llvm::is_contained(strides, ShapedType::kDynamic)) 226 return std::nullopt; 227 228 int64_t memrefSize = -1; 229 for (const auto &shape : enumerate(dims)) 230 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 231 232 return (offset + memrefSize) * *elementSize; 233 } 234 235 if (auto tensorType = dyn_cast<TensorType>(type)) { 236 if (!tensorType.hasStaticShape()) 237 return std::nullopt; 238 239 auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 240 if (!elementSize) 241 return std::nullopt; 242 243 int64_t size = *elementSize; 244 for (auto shape : tensorType.getShape()) 245 size *= shape; 246 247 return size; 248 } 249 250 // TODO: Add size computation for other types. 251 return std::nullopt; 252 } 253 254 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 255 static Type 256 convertScalarType(const spirv::TargetEnv &targetEnv, 257 const SPIRVConversionOptions &options, spirv::ScalarType type, 258 std::optional<spirv::StorageClass> storageClass = {}) { 259 // Get extension and capability requirements for the given type. 260 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 261 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 262 type.getExtensions(extensions, storageClass); 263 type.getCapabilities(capabilities, storageClass); 264 265 // If all requirements are met, then we can accept this type as-is. 
266 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 267 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 268 return type; 269 270 // Otherwise we need to adjust the type, which really means adjusting the 271 // bitwidth given this is a scalar type. 272 if (!options.emulateLT32BitScalarTypes) 273 return nullptr; 274 275 // We only emulate narrower scalar types here and do not truncate results. 276 if (type.getIntOrFloatBitWidth() > 32) { 277 LLVM_DEBUG(llvm::dbgs() 278 << type 279 << " not converted to 32-bit for SPIR-V to avoid truncation\n"); 280 return nullptr; 281 } 282 283 if (auto floatType = dyn_cast<FloatType>(type)) { 284 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 285 return Builder(targetEnv.getContext()).getF32Type(); 286 } 287 288 auto intType = cast<IntegerType>(type); 289 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 290 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 291 intType.getSignedness()); 292 } 293 294 /// Converts a sub-byte integer `type` to i32 regardless of target environment. 295 /// Returns a nullptr for unsupported integer types, including non sub-byte 296 /// types. 297 /// 298 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use 299 /// the above given that these sub-byte types are not supported at all in 300 /// SPIR-V; there are no compute/storage capability for them like other 301 /// supported integer types. 
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
                                      IntegerType type) {
  // NOTE(review): this guard admits widths up to and including 8 bits —
  // confirm whether i8 is intentionally handled here.
  if (type.getWidth() > 8) {
    LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
    return nullptr;
  }
  // Only packed sub-byte storage is implemented.
  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
    LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
    return nullptr;
  }

  if (!llvm::isPowerOf2_32(type.getWidth())) {
    LLVM_DEBUG(llvm::dbgs()
               << "unsupported non-power-of-two bitwidth in sub-byte" << type
               << "\n");
    return nullptr;
  }

  // Widen to i32 preserving signedness semantics.
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(type.getContext(), /*width=*/32,
                          type.getSignedness());
}

/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
static ShapedType
convertIndexElementType(ShapedType type,
                        const SPIRVConversionOptions &options) {
  Type indexType = dyn_cast<IndexType>(type.getElementType());
  if (!indexType)
    return type;

  return type.clone(getIndexType(type.getContext(), options));
}

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, VectorType type,
                  std::optional<spirv::StorageClass> storageClass = {}) {
  // Normalize index element types to the configured integer type first.
  type = cast<VectorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    // If this is not a spec allowed scalar type, try to handle sub-byte integer
    // types.
    auto intType = dyn_cast<IntegerType>(type.getElementType());
    if (!intType) {
      LLVM_DEBUG(llvm::dbgs()
                 << type
                 << " illegal: cannot convert non-scalar element type\n");
      return nullptr;
    }

    Type elementType = convertSubByteIntegerType(options, intType);
    if (!elementType)
      return nullptr;

    // Single-element vectors degenerate to their element type.
    if (type.getRank() <= 1 && type.getNumElements() == 1)
      return elementType;

    // SPIR-V vectors have at most 4 elements in this path.
    if (type.getNumElements() > 4) {
      LLVM_DEBUG(llvm::dbgs()
                 << type << " illegal: > 4-element unimplemented\n");
      return nullptr;
    }

    return VectorType::get(type.getShape(), elementType);
  }

  // Single-element vectors degenerate to a (possibly converted) scalar.
  if (type.getRank() <= 1 && type.getNumElements() == 1)
    return convertScalarType(targetEnv, options, scalarType, storageClass);

  if (!spirv::CompositeType::isValid(type)) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: not a valid composite type\n");
    return nullptr;
  }

  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise try converting the element type and rebuilding the vector.
  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), elementType);
  return nullptr;
}

/// Converts a complex `type` to a 2-element vector of its (converted) element
/// type under the given `targetEnv`. Element-type emulation is not supported:
/// if the element type would change, conversion fails.
static Type
convertComplexType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, ComplexType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (!elementType)
    return nullptr;
  // Widening the element would change the complex value's layout; bail out.
  if (elementType != type.getElementType()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: complex type emulation unsupported\n");
    return nullptr;
  }

  return VectorType::get(2, elementType);
}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relative large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return nullptr;
  }

  type = cast<TensorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  // Derive the element count from total byte size / element byte size.
  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  int64_t arrayElemCount = *tensorSize / *scalarSize;
  if (arrayElemCount == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot handle zero-element tensors\n");
    return nullptr;
  }

  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
  if (!arrayElemType)
    return nullptr;
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
}

/// Converts a memref of i1 elements to a pointer type, storing each bool in
/// an options.boolNumBits-bit integer (only 8-bit storage is implemented).
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                  const SPIRVConversionOptions &options,
                                  MemRefType type,
                                  spirv::StorageClass storageClass) {
  unsigned numBoolBits = options.boolNumBits;
  if (numBoolBits != 8) {
    LLVM_DEBUG(llvm::dbgs()
               << "using non-8-bit storage for bool types unimplemented");
    return nullptr;
  }
  auto elementType = dyn_cast<spirv::ScalarType>(
      IntegerType::get(type.getContext(), numBoolBits));
  if (!elementType)
    return nullptr;
  Type arrayElemType =
      convertScalarType(targetEnv, options, elementType, storageClass);
  if (!arrayElemType)
    return nullptr;
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Pack the bools into bytes, then count how many storage elements that needs.
  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}

/// Converts a memref of sub-byte integer elements to a pointer type, packing
/// elements into i32 storage words (see convertSubByteIntegerType).
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
                                     const SPIRVConversionOptions &options,
                                     MemRefType type,
                                     spirv::StorageClass storageClass) {
  IntegerType elementType = cast<IntegerType>(type.getElementType());
  Type arrayElemType = convertSubByteIntegerType(options, elementType);
  if (!arrayElemType)
    return nullptr;
  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Pack the sub-byte elements into bytes, then into storage elements.
  int64_t memrefSize =
      llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}

/// Converts a memref `type` to a SPIR-V pointer type. The memref's memory
/// space must already be a spirv::StorageClassAttr. Dispatches to the bool /
/// sub-byte paths where applicable.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              MemRefType type) {
  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!attr) {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " illegal: expected memory space to be a SPIR-V storage class "
           "attribute; please use MemorySpaceToStorageClassConverter to map "
           "numeric memory spaces beforehand\n");
    return nullptr;
  }
  spirv::StorageClass storageClass = attr.getValue();

  // Bool and sub-byte integer elements have dedicated packing paths.
  if (isa<IntegerType>(type.getElementType())) {
    if (type.getElementTypeBitWidth() == 1)
      return convertBoolMemrefType(targetEnv, options, type, storageClass);
    if (type.getElementTypeBitWidth() < 8)
      return convertSubByteMemrefType(targetEnv, options, type, storageClass);
  }

  // Convert the element type according to its kind.
  Type arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType)) {
    arrayElemType =
        convertVectorType(targetEnv, options, vecType, storageClass);
  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
    arrayElemType =
        convertComplexType(targetEnv, options, complexType, storageClass);
  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
    arrayElemType =
        convertScalarType(targetEnv, options, scalarType, storageClass);
  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
    type = cast<MemRefType>(convertIndexElementType(type, options));
    arrayElemType = type.getElementType();
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return nullptr;
  }
  if (!arrayElemType)
    return nullptr;

  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  if (*memrefSize == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}

//===----------------------------------------------------------------------===//
// Type casting materialization
//===----------------------------------------------------------------------===//

/// Converts the given `inputs` to the original source `type` considering the
/// `targetEnv`'s capabilities.
///
/// This function is meant to be used for source materialization in type
/// converters. When the type converter needs to materialize a cast op back
/// to some original source type, we need to check whether the original source
/// type is supported in the target environment. If so, we can insert legal
/// SPIR-V cast ops accordingly.
///
/// Note that in SPIR-V the capabilities for storage and compute are separate.
/// This function is meant to handle the **compute** side; so it does not
/// involve storage classes in its logic. The storage side is expected to be
/// handled by MemRef conversion logic.
static Value castToSourceType(const spirv::TargetEnv &targetEnv,
                              OpBuilder &builder, Type type, ValueRange inputs,
                              Location loc) {
  // We can only cast one value in SPIR-V.
  if (inputs.size() != 1) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }
  Value input = inputs.front();

  // Only support integer types for now. Floating point types to be implemented.
  if (!isa<IntegerType>(type)) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }
  auto inputType = cast<IntegerType>(input.getType());

  auto scalarType = dyn_cast<spirv::ScalarType>(type);
  if (!scalarType) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // Only support source type with a smaller bitwidth. This would mean we are
  // truncating to go back so we don't need to worry about the signedness.
  // For extension, we cannot have enough signal here to decide which op to use.
  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // Boolean values would need to use different ops than normal integer values.
  if (type.isInteger(1)) {
    // Lower i1 as (input == 1) rather than a bitwidth conversion.
    Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
    return builder.create<spirv::IEqualOp>(loc, input, one);
  }

  // Check that the source integer type is supported by the environment.
  SmallVector<ArrayRef<spirv::Extension>, 1> exts;
  SmallVector<ArrayRef<spirv::Capability>, 2> caps;
  scalarType.getExtensions(exts);
  scalarType.getCapabilities(caps);
  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
      failed(checkExtensionRequirements(type, targetEnv, exts))) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // We've already made sure this is truncating previously, so we don't need to
  // care about signedness here. Still try to use a corresponding op for better
  // consistency though.
  if (type.isSignedInteger()) {
    return builder.create<spirv::SConvertOp>(loc, type, input);
  }
  return builder.create<spirv::UConvertOp>(loc, type, input);
}

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//

/// Returns the spirv.GlobalVariable in `body` decorated with the given
/// `builtin`, or null if there is none.
static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                  spirv::BuiltIn builtin) {
  // Look through all global variables in the given `body` block and check if
  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
            spirv::SPIRVDialect::getAttributeName(
                spirv::Decoration::BuiltIn))) {
      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
      if (varBuiltIn && *varBuiltIn == builtin) {
        return varOp;
      }
    }
  }
  return nullptr;
}

/// Gets name of global variable for a builtin.
std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
                              StringRef suffix) {
  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
}

/// Gets or inserts a global variable for a builtin within `body` block.
756 static spirv::GlobalVariableOp 757 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 758 Type integerType, OpBuilder &builder, 759 StringRef prefix, StringRef suffix) { 760 if (auto varOp = getBuiltinVariable(body, builtin)) 761 return varOp; 762 763 OpBuilder::InsertionGuard guard(builder); 764 builder.setInsertionPointToStart(&body); 765 766 spirv::GlobalVariableOp newVarOp; 767 switch (builtin) { 768 case spirv::BuiltIn::NumWorkgroups: 769 case spirv::BuiltIn::WorkgroupSize: 770 case spirv::BuiltIn::WorkgroupId: 771 case spirv::BuiltIn::LocalInvocationId: 772 case spirv::BuiltIn::GlobalInvocationId: { 773 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), 774 spirv::StorageClass::Input); 775 std::string name = getBuiltinVarName(builtin, prefix, suffix); 776 newVarOp = 777 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 778 break; 779 } 780 case spirv::BuiltIn::SubgroupId: 781 case spirv::BuiltIn::NumSubgroups: 782 case spirv::BuiltIn::SubgroupSize: { 783 auto ptrType = 784 spirv::PointerType::get(integerType, spirv::StorageClass::Input); 785 std::string name = getBuiltinVarName(builtin, prefix, suffix); 786 newVarOp = 787 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 788 break; 789 } 790 default: 791 emitError(loc, "unimplemented builtin variable generation for ") 792 << stringifyBuiltIn(builtin); 793 } 794 return newVarOp; 795 } 796 797 //===----------------------------------------------------------------------===// 798 // Push constant storage 799 //===----------------------------------------------------------------------===// 800 801 /// Returns the pointer type for the push constant storage containing 802 /// `elementCount` 32-bit integer values. 
static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
                                                     Builder &builder,
                                                     Type indexType) {
  // 32-bit values => 4-byte stride; struct members start at offset 0.
  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
                                         /*stride=*/4);
  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
}

/// Returns the push constant variable containing `elementCount` 32-bit integer
/// values in `body`. Returns null op if such an op does not exist.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
                                                       unsigned elementCount) {
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
    auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
    if (!ptrType)
      continue;

    // Note that Vulkan requires "There must be no more than one push constant
    // block statically used per shader entry point." So we should always reuse
    // the existing one.
    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
      // Unwrap pointer -> struct -> array to read the element count.
      auto numElements = cast<spirv::ArrayType>(
                             cast<spirv::StructType>(ptrType.getPointeeType())
                                 .getElementType(0))
                             .getNumElements();
      if (numElements == elementCount)
        return varOp;
    }
  }
  return nullptr;
}

/// Gets or inserts a global variable for push constant storage containing
/// `elementCount` 32-bit integer values in `block`.
static spirv::GlobalVariableOp
getOrInsertPushConstantVariable(Location loc, Block &block,
                                unsigned elementCount, OpBuilder &b,
                                Type indexType) {
  // Reuse an existing push constant block of the right size if one exists.
  if (auto varOp = getPushConstantVariable(block, elementCount))
    return varOp;

  // Insert at the start of `block`; forward `b`'s listener so the caller's
  // rewriter still observes the newly created op.
  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
  auto type = getPushConstantStorageType(elementCount, builder, indexType);
  const char *name = "__push_constant_var__";
  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
                                                 /*initializer=*/nullptr);
}

//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//

/// A pattern for rewriting function signature to convert arguments of
/// functions to be of valid SPIR-V types.
struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
  using OpConversionPattern<func::FuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    FunctionType fnType = funcOp.getFunctionType();
    // SPIR-V functions can have at most one result.
    if (fnType.getNumResults() > 1)
      return failure();

    // Convert each argument type; bail out if any argument has no SPIR-V
    // equivalent under the attached type converter.
    TypeConverter::SignatureConversion signatureConverter(
        fnType.getNumInputs());
    for (const auto &argType : enumerate(fnType.getInputs())) {
      auto convertedType = getTypeConverter()->convertType(argType.value());
      if (!convertedType)
        return failure();
      signatureConverter.addInputs(argType.index(), convertedType);
    }

    // Convert the single result type, if there is one.
    Type resultType;
    if (fnType.getNumResults() == 1) {
      resultType = getTypeConverter()->convertType(fnType.getResult(0));
      if (!resultType)
        return failure();
    }

    // Create the converted spirv.func op.
    auto newFuncOp = rewriter.create<spirv::FuncOp>(
        funcOp.getLoc(), funcOp.getName(),
        rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                 resultType ? TypeRange(resultType)
                                            : TypeRange()));

    // Copy over all attributes other than the function name and type.
    for (const auto &namedAttr : funcOp->getAttrs()) {
      if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
          namedAttr.getName() != SymbolTable::getSymbolAttrName())
        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
    }

    // Move the body over and rewrite the entry block's argument types per
    // the signature conversion computed above.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(
            &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
      return failure();
    rewriter.eraseOp(funcOp);
    return success();
  }
};

/// A pattern for rewriting function signature to convert vector arguments of
/// functions to be of valid (unrolled) types.
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    FunctionType fnType = funcOp.getFunctionType();

    // TODO: Handle declarations.
    if (funcOp.isDeclaration()) {
      LLVM_DEBUG(llvm::dbgs()
                 << fnType << " illegal: declarations are unsupported\n");
      return failure();
    }

    // Create a new func op with the original type and copy the function body.
    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
                                                   funcOp.getName(), fnType);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    Location loc = newFuncOp.getBody().getLoc();

    Block &entryBlock = newFuncOp.getBlocks().front();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&entryBlock);

    OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());

    // For arguments that are of illegal types and require unrolling.
    // `unrolledInputNums` stores the indices of arguments that result from
    // unrolling in the new function signature. `newInputNo` is a counter.
    SmallVector<size_t> unrolledInputNums;
    size_t newInputNo = 0;

    // For arguments that are of legal types and do not require unrolling.
    // `tmpOps` stores a mapping from temporary operations that serve as
    // placeholders for new arguments that will be added later. These
    // operations will be erased once the entry block's argument list is
    // updated.
    llvm::SmallDenseMap<Operation *, size_t> tmpOps;

    // This counts the number of new operations created, so that the fixup
    // loop below knows where the freshly inserted prologue ops end.
    size_t newOpCount = 0;

    // Enumerate through the arguments.
    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
      // Check whether the argument is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // Non-vector argument: passes through unchanged. We need a
        // placeholder for the old argument that will be erased later.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // Already a legal vector shape: same placeholder treatment as the
        // non-vector case above.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());

      // Prepare the result vector (zero-initialized accumulator that the
      // strided inserts below fill in tile by tile).
      Value result = rewriter.create<arith::ConstantOp>(
          loc, origVecType, rewriter.getZeroAttr(origVecType));
      ++newOpCount;
      // Prepare the placeholder for the new arguments that will be added
      // later; each insert's source operand is rewired to the real argument
      // in the fixup loop below.
      Value dummy = rewriter.create<arith::ConstantOp>(
          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
      ++newOpCount;

      // Create the `vector.insert_strided_slice` ops, one per unrolled tile.
      SmallVector<int64_t> strides(targetShape->size(), 1);
      SmallVector<Type> newTypes;
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, dummy, result, offsets, strides);
        newTypes.push_back(unrolledType);
        unrolledInputNums.push_back(newInputNo);
        ++newInputNo;
        ++newOpCount;
      }
      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
      oneToNTypeMapping.addInputs(origInputNo, newTypes);
    }

    // Change the function signature to the fully converted 1:N input types.
    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
    rewriter.modifyOpInPlace(newFuncOp,
                             [&] { newFuncOp.setFunctionType(newFnType); });

    // Update the arguments in the entry block: drop all old arguments and
    // add the converted ones.
    entryBlock.eraseArguments(0, fnType.getNumInputs());
    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
    entryBlock.addArguments(convertedTypes, locs);

    // Replace the placeholder values with the new arguments. We assume there
    // is only one block for now.
    size_t unrolledInputIdx = 0;
    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
      // We first look for operands that are placeholders for initially legal
      // arguments.
      Operation &curOp = op;
      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
        Operation *operandOp = operandVal.getDefiningOp();
        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
          size_t idx = operandIdx;
          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
          });
        }
      }
      // Since all newly created operations are in the beginning, reaching the
      // end of them means that any later `vector.insert_strided_slice` should
      // not be touched.
      if (count >= newOpCount)
        continue;
      // Rewire each prologue insert's source from the `dummy` placeholder to
      // the corresponding unrolled argument, in creation order.
      if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
        size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
        rewriter.modifyOpInPlace(&curOp, [&] {
          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
        });
        ++unrolledInputIdx;
      }
    }

    // Erase the original funcOp. The `tmpOps` do not need to be erased since
    // they have no uses and will be handled by dead-code elimination.
    rewriter.eraseOp(funcOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// func::ReturnOp Conversion Patterns
//===----------------------------------------------------------------------===//

/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
                                PatternRewriter &rewriter) const override {
    // Check whether the parent funcOp is valid.
    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
    if (!funcOp)
      return failure();

    FunctionType fnType = funcOp.getFunctionType();
    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
    Location loc = returnOp.getLoc();

    // For the new return op.
    SmallVector<Value> newOperands;

    // Enumerate through the results.
    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
      // Check whether the result is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // Non-vector result: forwarded unchanged.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // The original operand can be used as-is.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());

      // Create `vector.extract_strided_slice` ops to form legal vectors from
      // the original operand of illegal type.
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
      SmallVector<int64_t> strides(originalShape.size(), 1);
      // Each extracted tile spans only the innermost dimension; the leading
      // dimensions of the extract shape are all 1.
      SmallVector<int64_t> extractShape(originalShape.size(), 1);
      extractShape.back() = targetShape->back();
      SmallVector<Type> newTypes;
      Value returnValue = returnOp.getOperand(origResultNo);
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, returnValue, offsets, extractShape, strides);
        // For rank > 1, peel away the leading unit dimensions so the operand
        // has the 1-D `unrolledType`.
        if (originalShape.size() > 1) {
          SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
          result =
              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
        }
        newOperands.push_back(result);
        newTypes.push_back(unrolledType);
      }
      oneToNTypeMapping.addInputs(origResultNo, newTypes);
    }

    // Change the function signature to match the unrolled result types.
    auto newFnType =
        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
    rewriter.modifyOpInPlace(funcOp,
                             [&] { funcOp.setFunctionType(newFnType); });

    // Replace the return op using the new operands. This will automatically
    // update the entry block as well.
    rewriter.replaceOp(returnOp,
                       rewriter.create<func::ReturnOp>(loc, newOperands));

    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public function for builtin variables
//===----------------------------------------------------------------------===//

Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           Type integerType, OpBuilder &builder,
                                           StringRef prefix, StringRef suffix) {
  // The builtin's global variable must live in the nearest symbol table
  // (e.g. the enclosing spirv.module).
  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  // Get or create the variable, then load its current value.
  spirv::GlobalVariableOp varOp =
      getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
                                 builtin, integerType, builder, prefix, suffix);
  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}

//===----------------------------------------------------------------------===//
// Public function for push constant storage
//===----------------------------------------------------------------------===//

Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
                                  unsigned offset, Type integerType,
                                  OpBuilder &builder) {
  Location loc = op->getLoc();
  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
      loc, parent->getRegion(0).front(), elementCount, builder, integerType);

  // Index into the struct (constant 0) and then into the array (`offset`).
  // NOTE(review): the offset constant is built with an i32 attribute while
  // its result type is `integerType` — confirm callers only pass 32-bit
  // integer types here.
  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
  Value offsetOp = builder.create<spirv::ConstantOp>(
      loc, integerType, builder.getI32IntegerAttr(offset));
  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
  auto acOp = builder.create<spirv::AccessChainOp>(
      loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
  return builder.create<spirv::LoadOp>(loc, acOp);
}

//===----------------------------------------------------------------------===//
// Public functions for index calculation
//===----------------------------------------------------------------------===//

Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
                                  int64_t offset, Type integerType,
                                  Location loc, OpBuilder &builder) {
  assert(indices.size() == strides.size() &&
         "must provide indices for all dimensions");

  // TODO: Consider moving to use affine.apply and patterns converting
  // affine.apply to standard ops. This needs converting to SPIR-V passes to
  // be broken down into progressive small steps so we can have intermediate
  // steps using other dialects. At the moment SPIR-V is the final sink.

  // linearIndex = offset + sum_i(indices[i] * strides[i]), built with
  // createOrFold so constant terms fold away.
  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
      loc, integerType, IntegerAttr::get(integerType, offset));
  for (const auto &index : llvm::enumerate(indices)) {
    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
        loc, integerType,
        IntegerAttr::get(integerType, strides[index.index()]));
    Value update =
        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
    linearizedIndex =
        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
  }
  return linearizedIndex;
}

Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                                       MemRefType baseType, Value basePtr,
                                       ValueRange indices, Location loc,
                                       OpBuilder &builder) {
  // Get base and offset of the MemRefType and verify they are static.

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  // Bail out (null result) on dynamic strides/offset: the access chain below
  // needs compile-time constants.
  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
      llvm::is_contained(strides, ShapedType::kDynamic) ||
      ShapedType::isDynamic(offset)) {
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType();

  SmallVector<Value, 2> linearizedIndices;
  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);

  // Add a '0' at the start to index into the struct.
  linearizedIndices.push_back(zero);

  if (baseType.getRank() == 0) {
    // Rank-0 memref: the single element sits at index 0 of the wrapped array.
    linearizedIndices.push_back(zero);
  } else {
    linearizedIndices.push_back(
        linearizeIndex(indices, strides, offset, indexType, loc, builder));
  }
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}

Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
                                       MemRefType baseType, Value basePtr,
                                       ValueRange indices, Location loc,
                                       OpBuilder &builder) {
  // Get base and offset of the MemRefType and verify they are static.

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  // Bail out (null result) on dynamic strides/offset: the access chain below
  // needs compile-time constants.
  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
      llvm::is_contained(strides, ShapedType::kDynamic) ||
      ShapedType::isDynamic(offset)) {
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType();

  SmallVector<Value, 2> linearizedIndices;
  Value linearIndex;
  if (baseType.getRank() == 0) {
    linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
  } else {
    linearIndex =
        linearizeIndex(indices, strides, offset, indexType, loc, builder);
  }
  // Array-typed pointees take the linear index as an access-chain index;
  // otherwise use a PtrAccessChain with the linear index as element offset.
  Type pointeeType =
      cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
  if (isa<spirv::ArrayType>(pointeeType)) {
    linearizedIndices.push_back(linearIndex);
    return builder.create<spirv::AccessChainOp>(loc, basePtr,
                                                linearizedIndices);
  }
  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
                                                 linearizedIndices);
}

Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                                 MemRefType baseType, Value basePtr,
                                 ValueRange indices, Location loc,
                                 OpBuilder &builder) {

  // Kernel capability implies the OpenCL flavor of addressing; otherwise use
  // the Vulkan (shader) flavor.
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
                               builder);
  }

  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
                             builder);
}

//===----------------------------------------------------------------------===//
// Public functions for vector unrolling
//===----------------------------------------------------------------------===//

int mlir::spirv::getComputeVectorSize(int64_t size) {
  // Prefer the largest SPIR-V-friendly vector width (4, then 3, then 2) that
  // evenly divides `size`; fall back to scalar.
  for (int i : {4, 3, 2}) {
    if (size % i == 0)
      return i;
  }
  return 1;
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
  VectorType srcVectorType = op.getSourceVectorType();
  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
  int64_t vectorSize =
      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
  return {vectorSize};
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
  // Unroll along every dimension except the innermost, which keeps a
  // compute-friendly vector width.
  VectorType vectorType = op.getResultVectorType();
  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
  nativeSize.back() =
      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
  return nativeSize;
}

std::optional<SmallVector<int64_t>>
mlir::spirv::getNativeVectorShape(Operation *op) {
  // Elementwise single-result ops: unroll to unit tiles except for the
  // innermost dimension.
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
      nativeSize.back() =
          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
      return nativeSize;
    }
  }

  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
      .Case<vector::ReductionOp, vector::TransposeOp>(
          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
      .Default([](Operation *) { return std::nullopt; });
}

LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
  MLIRContext *context = op->getContext();
  RewritePatternSet patterns(context);
  populateFuncOpVectorRewritePatterns(patterns);
  populateReturnOpVectorRewritePatterns(patterns);
  // We only want to apply signature conversion once to the existing func ops.
  // Without specifying strictMode, the greedy pattern rewriter will keep
  // looking for newly created func ops.
  GreedyRewriteConfig config;
  config.strictMode = GreedyRewriteStrictness::ExistingOps;
  return applyPatternsGreedily(op, std::move(patterns), config);
}

LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
  MLIRContext *context = op->getContext();

  // Unroll vectors in function bodies to native vector size.
  {
    RewritePatternSet patterns(context);
    auto options = vector::UnrollVectorOptions().setNativeShapeFn(
        [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
    populateVectorUnrollPatterns(patterns, options);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }

  // Convert transpose ops into extract and insert pairs, in preparation of
  // further transformations to canonicalize/cancel.
  {
    RewritePatternSet patterns(context);
    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
        vector::VectorTransposeLowering::EltWise);
    vector::populateVectorTransposeLoweringPatterns(patterns, options);
    vector::populateVectorShapeCastLoweringPatterns(patterns);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }

  // Run canonicalization to cast away leading size-1 dimensions.
  {
    RewritePatternSet patterns(context);

    // We need to pull in casting away leading one dims.
    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
    vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
    vector::TransposeOp::getCanonicalizationPatterns(patterns, context);

    // Decompose different-rank insert_strided_slice and n-D
    // extract_strided_slice.
    vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
        patterns);
    vector::InsertOp::getCanonicalizationPatterns(patterns, context);
    vector::ExtractOp::getCanonicalizationPatterns(patterns, context);

    // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
    // them up.
    vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);

    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SPIR-V TypeConverter
//===----------------------------------------------------------------------===//

SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                       const SPIRVConversionOptions &options)
    : targetEnv(targetAttr), options(options) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: This assumes that the SPIR-V types are valid to use in the given
  // target environment, which should be the case if the whole pipeline is
  // driven by the same target environment. Still, we probably still want to
  // validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` lowers to the converter-configured integer width.
  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

  addConversion([this](IntegerType intType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    // Sub-byte integers (width < 8) get dedicated widening treatment.
    if (intType.getWidth() < 8)
      return convertSubByteIntegerType(this->options, intType);
    return Type();
  });

  addConversion([this](FloatType floatType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](ComplexType complexType) {
    return convertComplexType(this->targetEnv, this->options, complexType);
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(this->targetEnv, this->options, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(this->targetEnv, this->options, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(this->targetEnv, this->options, memRefType);
  });

  // Register some last line of defense casting logic.
  addSourceMaterialization(
      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
      });
  // Target materialization: bridge unresolved type mismatches with an
  // unrealized_conversion_cast, to be reconciled later.
  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
                              Location loc) {
    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return cast.getResult(0);
  });
}

Type SPIRVTypeConverter::getIndexType() const {
  return ::getIndexType(getContext(), options);
}

MLIRContext *SPIRVTypeConverter::getContext() const {
  return targetEnv.getAttr().getContext();
}

bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
  return targetEnv.allows(capability);
}

//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//

std::unique_ptr<SPIRVConversionTarget>
SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
  std::unique_ptr<SPIRVConversionTarget> target(
      // std::make_unique does not work here because the constructor is
      // private.
      new SPIRVConversionTarget(targetAttr));
  SPIRVConversionTarget *targetPtr = target.get();
  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
      // We need to capture the raw pointer here because it is stable:
      // target will be destroyed once this function is returned.
      [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
  return target;
}

SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}

bool SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version.
  // Ops not implementing QueryMinVersionInterface/QueryMaxVersionInterface
  // are available to all SPIR-V versions.
  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
    std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
    if (minVersion && *minVersion > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(*minVersion) << "\n");
      return false;
    }
  }
  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
    std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
    if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(*maxVersion) << "\n");
      return false;
    }
  }

  // Make sure this op's required extensions are allowed to use. Ops not
  // implementing QueryExtensionInterface do not require extensions to be
  // available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed to use. Ops not
  // implementing QueryCapabilityInterface do not require capabilities to be
  // available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  // Gather all operand and result types for the type-level checks below.
  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Ensure that all types have been converted to SPIRV types.
  if (llvm::any_of(valueTypes,
                   [](Type t) { return !isa<spirv::SPIRVType>(t); }))
    return false;

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes. (Pushed after the any_of check above;
  // presumably the op verifier guarantees this type is a SPIR-V pointer
  // type, so the casts below stay safe — confirm against GlobalVariableOp's
  // verifier.)
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.getType());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}

//===----------------------------------------------------------------------===//
// Public functions for populating patterns
//===----------------------------------------------------------------------===//

void mlir::populateBuiltinFuncToSPIRVPatterns(
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}

void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
}

void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
}